As I understand it, RLlib computes all the losses at once and applies the gradient from the accumulated loss (e.g., critic + actor loss). I wonder if I could separate this process into multiple gradient updates, for example, calculating the critic loss and applying its gradient "before" calculating the actor loss.
For tf, yes, you can specify a custom gradient function for your policy (but you would have to re-"build" a new policy class via rllib/policy/tf_policy_template/build_tf_policy). Then specify the gradients_fn and/or apply_gradients_fn.
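Roughly, that would look something like the sketch below. This assumes an older Ray ~1.x API (the parameter names may differ across versions, e.g. gradients_fn vs. compute_gradients_fn), and my_loss_fn is just a placeholder for your real actor + critic loss:

```python
# Minimal sketch, assuming the older Ray ~1.x build_tf_policy signature.
# `my_loss_fn`, `my_gradients_fn`, `my_apply_gradients_fn` are placeholders.
import tensorflow as tf
from ray.rllib.policy.tf_policy_template import build_tf_policy


def my_loss_fn(policy, model, dist_class, train_batch):
    # Placeholder: a real loss_fn would build the actor and critic losses here
    # (and could store the individual terms on `policy` for use below).
    logits, _ = model(train_batch)
    return tf.reduce_mean(logits)


def my_gradients_fn(policy, optimizer, loss):
    # Full control over how gradients are computed (e.g. per loss term).
    return optimizer.compute_gradients(loss)


def my_apply_gradients_fn(policy, optimizer, grads_and_vars):
    # Full control over how, and in which order, the gradients are applied.
    return optimizer.apply_gradients(grads_and_vars)


MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=my_loss_fn,
    gradients_fn=my_gradients_fn,
    apply_gradients_fn=my_apply_gradients_fn,
)
```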
For torch, I don't think we have unified this behavior yet. You can only modify the already-calculated gradients; the loss + grad-calc + grad-apply order is fixed. We should make this more flexible here as well, though.
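What you can do today is post-process the computed gradients, for example via the extra_grad_process hook on TorchPolicy (a sketch only; the hook's exact signature has changed across Ray versions):

```python
# Sketch, assuming TorchPolicy calls `extra_grad_process(optimizer, loss)`
# after loss.backward() and before optimizer.step() (signature varies by version).
import torch
from ray.rllib.policy.torch_policy import TorchPolicy


class GradClippingTorchPolicy(TorchPolicy):
    def extra_grad_process(self, optimizer, loss):
        # Modify the already-computed gradients, e.g. clip them by global norm.
        total_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), 40.0)
        return {"grad_gnorm": float(total_norm)}
```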
Thank you for your reply!
Since I prefer torch, I went with overriding TorchPolicy, just like the QMIX code does: basically, overriding the learn_on_batch function and controlling the learning flow myself.
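For anyone finding this later, here is the rough shape of it. The _critic_loss/_actor_loss helpers and the two optimizers are hypothetical names, not the actual QMIX code:

```python
# Rough sketch: override TorchPolicy.learn_on_batch to run the critic update
# before the actor update. Helper/optimizer names are hypothetical.
import torch
from ray.rllib.policy.torch_policy import TorchPolicy


class TwoStepTorchPolicy(TorchPolicy):
    def learn_on_batch(self, samples):
        # 1) Critic update: compute the critic loss and step its optimizer first.
        critic_loss = self._critic_loss(samples)      # hypothetical helper
        self._critic_optimizer.zero_grad()
        critic_loss.backward()
        self._critic_optimizer.step()

        # 2) Actor update: only now compute and apply the actor loss, so it
        #    already sees the updated critic.
        actor_loss = self._actor_loss(samples)        # hypothetical helper
        self._actor_optimizer.zero_grad()
        actor_loss.backward()
        self._actor_optimizer.step()

        # Return training stats for this batch.
        return {
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss.item(),
        }
```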