As I understand it, RLlib computes all the losses at once and applies the gradient from the accumulated loss (e.g., critic + actor loss). I wonder if I could separate this process into multiple gradient updates, for example, calculating the critic loss and applying its gradient "before" calculating the actor loss.
For tf, yes, you can specify a custom gradient function for your policy, but you would have to re-"build" a new policy class and specify the custom gradient function when doing so.
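A rough sketch of what such a custom gradient function could look like. This assumes RLlib's policy-building API accepts a gradient-override hook (the exact kwarg name varies across Ray versions); `custom_gradients_fn` and the commented-out `build_tf_policy` call are illustrative, not a confirmed API:

```python
# Hedged sketch: a gradient function with the (policy, optimizer, loss)
# shape that RLlib's tf policy builders have used. The surrounding
# build_tf_policy wiring is shown only as a commented assumption.
def custom_gradients_fn(policy, optimizer, loss):
    # Compute gradients yourself here; this is where you could split,
    # clip, or reorder them instead of using the default single pass.
    grads_and_vars = optimizer.compute_gradients(loss)
    return grads_and_vars

# Hypothetical usage (kwarg name may differ by Ray version):
# MyTFPolicy = build_tf_policy(
#     name="MyTFPolicy",
#     loss_fn=my_loss_fn,
#     gradients_fn=custom_gradients_fn,
# )
```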
For torch, I don't think we have unified this behavior yet. You can only modify the gradients once they are calculated; the loss + grad-calc + grad-apply order is fixed. We should make this more flexible here as well, though.
Thank you for your reply.
I went with overriding TorchPolicy instead, just like the QMIX code does, since I prefer torch. Basically, I override the learn_on_batch function and control the learning flow myself.
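A minimal standalone sketch of the control flow described above: update the critic and apply its gradient first, then compute the actor loss against the already-updated critic. This is plain PyTorch, not RLlib's actual TorchPolicy; the class and method names mirror the idea only, and the tiny linear actor/critic is purely illustrative:

```python
import torch
import torch.nn as nn

class TwoStepLearner:
    """Illustrative two-step update: critic step before actor step."""

    def __init__(self, obs_dim=4, act_dim=2):
        self.actor = nn.Linear(obs_dim, act_dim)
        self.critic = nn.Linear(obs_dim, 1)
        # Separate optimizers so each loss gets its own gradient application.
        self.actor_opt = torch.optim.SGD(self.actor.parameters(), lr=0.01)
        self.critic_opt = torch.optim.SGD(self.critic.parameters(), lr=0.01)

    def learn_on_batch(self, obs, returns):
        # 1) Critic update: compute the loss and apply its gradient
        #    before the actor loss is even calculated.
        critic_loss = ((self.critic(obs).squeeze(-1) - returns) ** 2).mean()
        self.critic_opt.zero_grad()
        critic_loss.backward()
        self.critic_opt.step()

        # 2) Actor update: the advantage now uses the freshly updated critic.
        with torch.no_grad():
            advantage = returns - self.critic(obs).squeeze(-1)
        logp = torch.log_softmax(self.actor(obs), dim=-1)[:, 0]
        actor_loss = -(logp * advantage).mean()
        self.actor_opt.zero_grad()
        actor_loss.backward()
        self.actor_opt.step()

        return {"critic_loss": critic_loss.item(),
                "actor_loss": actor_loss.item()}
```

In an actual TorchPolicy subclass, the same ordering would live inside the overridden learn_on_batch, with the sample batch supplying obs and returns.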