Independent gradient update for each loss

As I understand it, RLlib computes all the losses at once and applies the gradient from the accumulated loss (e.g., critic + actor loss). I wonder if I could separate this into multiple gradient updates, for example calculating the critic loss and applying its gradient "before" calculating the actor loss.

For tf, yes, you can specify a custom gradient function for your policy (but you would have to re-“build” a new policy class via build_tf_policy in rllib/policy/tf_policy_template), and then specify gradients_fn and/or apply_gradients_fn.
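For illustration, a rough sketch of wiring in those two hooks could look like the following. It assumes the build_tf_policy signature from RLlib ~1.x graph mode (check rllib/policy/tf_policy_template.py in your installed version); my_loss_fn is just a placeholder, and the comments only hint at where you would split critic vs. actor handling:

```python
# Sketch only (untested): custom gradient hooks via build_tf_policy.
# Signatures assumed from RLlib ~1.x; my_loss_fn is a placeholder.
import tensorflow as tf
from ray.rllib.policy.tf_policy_template import build_tf_policy


def my_loss_fn(policy, model, dist_class, train_batch):
    # Placeholder loss so the sketch is complete; replace with your real
    # critic/actor losses (you could stash them on `policy` for use below).
    logits, _ = model.from_batch(train_batch)
    return tf.reduce_mean(logits)


def my_gradients_fn(policy, optimizer, loss):
    # Compute gradients yourself. With separate critic/actor losses stored
    # on `policy`, you could build two separate grads-and-vars lists here.
    return optimizer.compute_gradients(loss)


def my_apply_gradients_fn(policy, optimizer, grads_and_vars):
    # Return the op that applies the gradients. With two gradient lists you
    # could chain two apply ops via tf.control_dependencies to enforce a
    # "critic first, then actor" ordering inside a single session run.
    return optimizer.apply_gradients(grads_and_vars)


MyTFPolicy = build_tf_policy(
    name="MyTFPolicy",
    loss_fn=my_loss_fn,
    gradients_fn=my_gradients_fn,
    apply_gradients_fn=my_apply_gradients_fn,
)
```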

For torch, I don’t think we have unified this behavior yet. You can only modify already-calculated gradients; the loss → grad-calc → grad-apply order is fixed. We should make this more flexible here as well, though.

Thank you for your reply :slight_smile:
Since I prefer torch, I went with overriding TorchPolicy instead, just like the QMIX code does: basically, I override the learn_on_batch function and control the learning flow myself (a rough sketch is below).
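For anyone finding this later, a minimal sketch of that approach, assuming a TorchPolicy subclass where critic_loss, actor_loss, critic_optimizer, and actor_optimizer are hypothetical helpers/members you set up yourself (e.g., in __init__), not RLlib API:

```python
# Sketch only (untested): override learn_on_batch to run two separate
# gradient updates, critic first, then actor.
from ray.rllib.policy.torch_policy import TorchPolicy


class TwoStepPolicy(TorchPolicy):
    def learn_on_batch(self, samples):
        # Convert the SampleBatch to torch tensors (TorchPolicy helper).
        train_batch = self._lazy_tensor_dict(samples)

        # 1) Critic update: compute the critic loss and apply its gradient.
        critic_loss = self.critic_loss(train_batch)      # hypothetical helper
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # 2) Actor update: only now compute the actor loss, so it sees the
        #    freshly updated critic.
        actor_loss = self.actor_loss(train_batch)        # hypothetical helper
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        # Return whatever stats format your Trainer expects.
        return {
            "critic_loss": critic_loss.item(),
            "actor_loss": actor_loss.item(),
        }
```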
