Soft actor-critic KL divergence policy loss blow-up

Hi guys,

This is not directly related to RLlib, but I thought this community could help with this question.

I am experiencing a KL-divergence policy loss blow-up with soft actor-critic (SAC). I am doing inverse reinforcement learning on a custom point-mass environment, and I use SAC to compute the MaxEnt Q function and policy.

As I understand it, my Q-value and policy losses are computed as:

# first, sample (state, action, reward, next_state) from the replay buffer
next_action, next_logp = policy_network(next_state)   # next action and its log-prob under the current policy
Q_next = q_network(next_state, next_action)
Q_now = q_network(state, action)

# td loss
Q_target = reward - next_logp + Q_next
Q_loss = F.mse_loss(Q_now, Q_target.detach())

# kl divergence loss
pi_loss = torch.mean(next_logp - Q_next.detach())
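
For comparison, here is a minimal sketch of how I read the same two losses in the standard SAC formulation (the policy loss resamples the action at the current state). gamma, alpha, and q_target_network are placeholders for a discount factor, the entropy temperature, and a target critic, which my code above does not have:

# standard SAC losses (sketch; gamma, alpha, q_target_network are placeholders)
with torch.no_grad():
    next_action, next_logp = policy_network(next_state)
    Q_target = reward + gamma * (q_target_network(next_state, next_action) - alpha * next_logp)
Q_loss = F.mse_loss(q_network(state, action), Q_target)

# policy loss: resample the action at the current state, not the next state
new_action, new_logp = policy_network(state)
pi_loss = torch.mean(alpha * new_logp - q_network(state, new_action))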

Following the paper, I use a 1e-4 learning rate for both the q_network and the policy_network. I did not apply the tanh transformation, since I have demonstrations to guide my action range.
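
For reference, this is roughly what the tanh transformation would look like if I did apply it (a sketch of my understanding, not code I am actually running):

import torch
import torch.distributions as D

def sample_squashed_action(mean, log_std):
    # reparameterized Gaussian sample, then squash into (-1, 1)
    dist = D.Normal(mean, log_std.exp())
    u = dist.rsample()
    a = torch.tanh(u)
    # change-of-variables correction so logp is the log-density of the squashed action
    logp = dist.log_prob(u).sum(-1) - torch.log(1 - a.pow(2) + 1e-6).sum(-1)
    return a, logp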

Here is what my loss curves look like. traj_loss is the action log-likelihood of my observed trajectories. I am operating in a partially observable environment and use a variational autoencoder to update the agent’s beliefs, so infer_loss is the negative ELBO. q_loss and pi_loss are the losses above.

One thing about these curves that is very different from typical loss curves is that pi_loss heads straight down after a few iterations, and that is exactly where q_loss starts to blow up.
[screenshot of the loss curves]

I actually tried two versions of the loss function. In the first version I follow SAC and compute Q_target as:

Q_target = reward - next_logp + Q_next

In the second version I don’t maximize entropy, so Q_target is computed as:

Q_target = reward + Q_next

With the first version, q_loss simply never improves (this is the figure above, although you can’t see it there). With the second version (no max-ent), q_loss improves for a few iterations until the policy heads straight down and q_loss also blows up, which results in essentially the same loss curve, so I didn’t include it.

I think this loss-curve behavior makes sense, since the Gaussian policy can get infinitely precise, which results in the blow-up. I don’t think applying tanh would lead to different behavior, since it is applied to the mean, not the variance. I also wondered whether the original paper and the common implementations apply a lower bound to the variance, but I did not find one in any implementation, and there is no discussion of this issue anywhere.
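
To be concrete, the kind of variance lower bound I have in mind would look roughly like this in the policy head (policy_trunk and the clamp range are placeholders I made up, not something from the paper):

LOG_STD_MIN, LOG_STD_MAX = -20.0, 2.0    # illustrative bounds, not from the paper

def policy_head(obs):
    mean, log_std = policy_trunk(obs)                          # policy_trunk: placeholder for the network body
    log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX)   # keep the Gaussian from collapsing to zero variance
    return torch.distributions.Normal(mean, log_std.exp())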

I think it is a matter of hyperparameter tuning. Could you try Pendulum and see if the same phenomenon happens? It should learn in < 7k timesteps.
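
Something like this is what I have in mind, just as a sanity check (a rough sketch using the old gym step/reset API; policy_network, the buffer, and the update are whatever you already have):

import gym
import torch

env = gym.make("Pendulum-v0")
obs = env.reset()
for t in range(7000):
    with torch.no_grad():
        action, _ = policy_network(torch.as_tensor(obs, dtype=torch.float32))
    next_obs, reward, done, _ = env.step(action.numpy())
    # store (obs, action, reward, next_obs) in your buffer and run your Q/pi update here
    obs = env.reset() if done else next_obs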


You need to add an epsilon to the denominator for numerical stability when you take the KL divergence; otherwise you will divide by nearly zero, take something like -log(0.00001), and your loss will blow up.
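
Roughly like this (the epsilon value and the names pi, p, q are just examples):

eps = 1e-6
logp = torch.log(pi + eps)    # never take log of exactly zero
ratio = p / (q + eps)         # never divide by (near) zero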


Thanks for the suggestion, I will try that.

Hi, I don’t see a denominator in the equations. Do you mean adding a lower bound inside the log, like log(pi + 1e-5)? That’s a good suggestion; would it make the loss curve convex?