Hey @Samuel_Fipps, I saw that you were looking at the clamping of the log_stds that @mannyv has been pointing at for, what seems like, years now lol.
So, the actual clamping range of the log_stds depends on your action space. In the example given above from the PPO NaN logits thread, I unknowingly clamped the log_stds to (-1.0, 1.0); instead, I will give you a breakdown here of what it should be (if this is your problem).
If you have your means from the policy logits normalized between (-1, 1), then your log_stds should be handled like this:
```python
logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
means, log_stds = torch.chunk(logits, 2, -1)

# Means are assumed to be normalized between -1 and 1, but even then
# I have noticed that some actions in the batch fall outside these bounds.
means_clamped = torch.clamp(means, -1, 1)

# With means in [-1, 1], we want the std_dev to land in roughly (0, 1]:
# exp(-10) ~ 0.0000454 (near 0, but not 0) and exp(0) = 1.
log_stds_clamped = torch.clamp(log_stds, -10, 0)

logits = torch.cat((means_clamped, log_stds_clamped), dim=-1)
self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
```
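For intuition, here is a minimal, self-contained sketch (outside of any RLlib model, with a made-up `logits` tensor) of what the clamp buys you downstream. `torch.distributions.Normal` stands in for RLlib's diagonal Gaussian action distribution, which splits the logits the same way:

```python
import torch

# Toy stand-in for the actor network output: batch of 5, 2 actions,
# so 4 logits per row (2 means + 2 log_stds). These numbers are made up.
logits = torch.randn(5, 4) * 3.0

means, log_stds = torch.chunk(logits, 2, -1)
means = torch.clamp(means, -1, 1)
log_stds = torch.clamp(log_stds, -10, 0)

# A diagonal Gaussian built from the clamped logits:
dist = torch.distributions.Normal(means, torch.exp(log_stds))
actions = dist.sample()

# std is now guaranteed to be in [exp(-10), 1], so log_prob() and
# entropy() stay finite instead of blowing up to NaN.
print(torch.exp(log_stds).min().item(), torch.exp(log_stds).max().item())
```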
So if the mean of your normalized action space is 0, then your standard deviation's minimum can be near 0 (but not exactly 0) and its maximum can be 1. Since the network outputs log standard deviations and the distribution needs plain standard deviations, we take exp() of them, which means we need log_std bounds whose exp() lands near 0 and at 1.
So, what I have concluded is that I can use the following (sanity-checked in the snippet after the list):
- to get near 0, we can set the lower bound as exp(-10) which is ~ 0.00004539992
- to get 1, we can set the upper bound as exp(0) which is 1
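A quick check of those two bounds (nothing RLlib-specific here):

```python
import torch

# exp(-10) is near 0, exp(0) is exactly 1.
print(torch.exp(torch.tensor([-10.0, 0.0])))
# tensor([4.5400e-05, 1.0000e+00])
```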
Therefore, the range we clamp the log_stds to should be chosen with reference to the action space. Initially I did it arbitrarily between (-1, 1), like you have; that might work, but it doesn't seem to be correct.
@mannyv feel free to correct me if you think this is wrong.