Hi guys,
This is not directly related to RLlib, but I thought this community could help with this question.
I am seeing my KL loss blow up with soft actor-critic (SAC). I am doing inverse reinforcement learning in a custom point-mass environment and use SAC to compute the MaxEnt Q-function and policy.
As I understand it, my Q-value and policy losses are computed as follows:
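```python
import torch
import torch.nn.functional as F

# first sample state, action, next_state, reward from the replay buffer
next_action = policy_network(next_state)
# logp: log-probability of next_action under the current policy
Q_next = q_network(next_state, next_action)
Q_now = q_network(state, action)
# TD loss (target is detached so no gradient flows into it)
Q_target = reward - logp + Q_next
Q_loss = F.mse_loss(Q_now, Q_target.detach())
# KL divergence loss (Q is detached, so only logp carries policy gradients)
pi_loss = torch.mean(logp - Q_next.detach())
```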
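For comparison, here is a minimal NumPy sketch of the soft Bellman target and policy objective in the usual SAC formulation, which adds a discount factor gamma and a temperature alpha that the pseudocode above leaves out. The function names, the done mask, and the defaults gamma=0.99 and alpha=1.0 are illustrative assumptions, and the sketch only computes values (no gradients). With alpha=1, gamma=1 and done=0 the target reduces to the first version of Q_target in this post.

```python
import numpy as np

def soft_q_target(reward, next_q, next_logp, done, gamma=0.99, alpha=1.0):
    """Soft Bellman target: r + gamma * (1 - done) * (Q(s', a') - alpha * log pi(a'|s')).

    next_q and next_logp are evaluated at a' sampled from the current policy at s'.
    """
    reward, next_q, next_logp, done = map(np.asarray, (reward, next_q, next_logp, done))
    soft_value = next_q - alpha * next_logp      # soft state value of s'
    return reward + gamma * (1.0 - done) * soft_value

def policy_objective(logp, q, alpha=1.0):
    """Value of the SAC policy loss: mean(alpha * log pi(a|s) - Q(s, a)) with a ~ pi(.|s)."""
    return float(np.mean(alpha * np.asarray(logp) - np.asarray(q)))

# Example batch of two transitions; the second one is terminal.
target = soft_q_target(reward=[1.0, 1.0], next_q=[10.0, 10.0],
                       next_logp=[-2.0, -2.0], done=[0.0, 1.0])
```

In the usual SAC update, the logp and q passed to policy_objective are evaluated at an action freshly sampled for state (not next_state) with the reparameterization trick, and the policy gradient flows through Q into that action while the Q-network parameters are left untouched.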
As in the paper, I use a learning rate of 1e-4 for both q_network and policy_network. I did not apply the tanh transformation, since I have demonstrations to guide my action range.
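For reference, the tanh transformation in the SAC paper (its appendix on enforcing action bounds) squashes a sampled Gaussian action u into a = tanh(u) and corrects the log-likelihood with the change-of-variables term log(1 - tanh(u)^2). Below is a minimal NumPy sketch, assuming a diagonal Gaussian parameterized by mean mu and log standard deviation log_std; these names are chosen for illustration.

```python
import numpy as np

def gaussian_logp(u, mu, log_std):
    """Log-density of u under a diagonal N(mu, exp(log_std)^2), summed over action dims."""
    u = np.atleast_1d(np.asarray(u, dtype=float))
    mu, log_std = np.asarray(mu, dtype=float), np.asarray(log_std, dtype=float)
    z = (u - mu) / np.exp(log_std)
    return np.sum(-0.5 * z**2 - log_std - 0.5 * np.log(2.0 * np.pi), axis=-1)

def squashed_logp(u, mu, log_std):
    """Log-density of a = tanh(u): gaussian_logp(u) - sum(log(1 - tanh(u)^2))."""
    u = np.atleast_1d(np.asarray(u, dtype=float))
    # Numerically stable form of log(1 - tanh(u)^2) = 2 * (log 2 - u - softplus(-2u))
    log_det = 2.0 * (np.log(2.0) - u - np.logaddexp(0.0, -2.0 * u))
    return gaussian_logp(u, mu, log_std) - np.sum(log_det, axis=-1)
```

Because the correction is evaluated at the sampled u, it enters logp, and therefore both the Q target and pi_loss.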
Here is what my loss curves look like. traj_loss is the action log-likelihood of my observed trajectories. I am operating in a partially observable environment and use a variational autoencoder to update the agent’s beliefs, so infer_loss is the negative ELBO. q_loss and pi_loss are the losses defined above.
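As an illustration of what a loss like traj_loss measures, here is a sketch of the mean negative log-likelihood of demonstrated actions under a diagonal Gaussian policy. It assumes the policy outputs a mean and a log standard deviation per action dimension; the function and argument names are hypothetical.

```python
import numpy as np

def demo_nll(actions, mu, log_std):
    """Mean negative log-likelihood of demonstrated actions under N(mu, exp(log_std)^2).

    actions, mu, log_std: arrays of shape (batch, action_dim).
    """
    actions, mu, log_std = (np.asarray(x, dtype=float) for x in (actions, mu, log_std))
    z = (actions - mu) / np.exp(log_std)
    logp = np.sum(-0.5 * z**2 - log_std - 0.5 * np.log(2.0 * np.pi), axis=-1)
    return float(-np.mean(logp))

# Two 2-D demonstrated actions that coincide with the policy mean, unit std.
nll = demo_nll(actions=[[0.1, -0.2], [0.3, 0.0]],
               mu=[[0.1, -0.2], [0.3, 0.0]],
               log_std=[[0.0, 0.0], [0.0, 0.0]])
```

If traj_loss is computed this way, note that when a demonstrated action sits exactly at the mean, the loss decreases without bound as log_std goes to minus infinity, so such a term rewards shrinking the variance.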
What distinguishes these curves from most loss curves I have seen is that pi_loss heads straight down after a few iterations, which is exactly where q_loss starts to blow up.
I actually tried two versions of the loss function. In the first version, I follow SAC and compute Q_target as:
```python
Q_target = reward - logp + Q_next
```
In the second version, I don’t maximize entropy, so Q_target is computed as:
```python
Q_target = reward + Q_next
```
With the first version, q_loss simply never improves (this is in the figure, although it is hard to see). With the second version (no max-ent), q_loss improves for a few iterations until pi_loss heads straight down, at which point q_loss blows up as well. This produces the same loss curves, so I did not include them.
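Both targets above are written without a discount factor. If that reflects the actual code rather than a simplification in this post, it may be worth noting that repeatedly applying an undiscounted backup on a single self-looping state with a constant per-step term never settles, while a discount gamma < 1 makes it converge to (per-step term) / (1 - gamma). A toy scalar sketch, where bonus stands in for the -logp entropy term of the first version (all names and numbers are illustrative):

```python
def repeated_backup(reward, bonus=0.0, gamma=1.0, steps=100):
    """Apply Q <- reward + bonus + gamma * Q repeatedly on one self-looping state, from Q = 0."""
    q = 0.0
    for _ in range(steps):
        q = reward + bonus + gamma * q
    return q

undiscounted = repeated_backup(reward=1.0, gamma=1.0, steps=1000)            # grows linearly: 1000.0
discounted = repeated_backup(reward=1.0, gamma=0.99, steps=5000)             # approaches 1 / 0.01 = 100
with_bonus = repeated_backup(reward=1.0, bonus=0.5, gamma=0.99, steps=5000)  # approaches 1.5 / 0.01 = 150
```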
I think this behavior makes sense, since the Gaussian policy can become infinitely precise (its variance shrinking toward zero), which causes the blow-up. I don’t think applying tanh would change this, since it is applied to the mean, not the variance. I also wondered whether the original paper and many implementations simply put a lower bound on the variance, but I did not find one in any implementation. I also could not find any discussion of this issue.
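To make the "infinitely precise" point concrete: for a one-dimensional Gaussian, the log-density at its own mean is -log_std - 0.5 * log(2 * pi), which grows without bound as log_std goes to minus infinity, and any logp term grows with it. One way to impose the kind of lower bound on the variance mentioned above is to clamp log_std before building the distribution. The sketch assumes the policy network outputs log_std directly; the bounds LOG_STD_MIN = -20 and LOG_STD_MAX = 2 are illustrative values, not taken from this post.

```python
import numpy as np

LOG_STD_MIN, LOG_STD_MAX = -20.0, 2.0  # illustrative bounds (assumption)

def logp_at_mean(log_std):
    """Log-density of a 1-D Gaussian evaluated at its own mean."""
    return -log_std - 0.5 * np.log(2.0 * np.pi)

def clamp_log_std(log_std):
    """Bound log_std so the policy can neither collapse nor explode its variance."""
    return np.clip(log_std, LOG_STD_MIN, LOG_STD_MAX)

for raw in [0.0, -5.0, -50.0, -500.0]:
    print(f"log_std={raw:7.1f}  unclamped logp={logp_at_mean(raw):8.2f}  "
          f"clamped logp={logp_at_mean(clamp_log_std(raw)):6.2f}")
```

In PyTorch the same bound can be applied with torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX); with the clamp, logp at the mean stays below 20 - 0.5 * log(2 * pi) no matter what the network outputs.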