Sorry for the late reply @hermmanhender, but yes, as @mannyv stated, you either have to add the log standard deviations as parameters of the model, or you can clamp them to something like [-1, 1]. In my experience, parameterizing them allows for less exploration, and when I pit the two versions against each other in custom multi-agent scenarios, the clamped version wins more often. The parameterized version looks something like this:
```python
# In __init__():
self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
self.actor_means = TorchFC(obs_space, action_space, action_space.shape[0], model_config, name + "_actor")
self.log_std_init = model_config["custom_model_config"].get("log_std_init", 0)
self.log_stds = nn.Parameter(torch.ones(action_space.shape[0]) * self.log_std_init, requires_grad=True)

def forward(self, input_dict, state, seq_lens):
    means, _ = self.actor_means(input_dict, state, seq_lens)
    # Broadcast the single learned log-std vector across the batch.
    log_stds = self.log_stds.expand_as(means)
    logits = torch.cat((means, log_stds), dim=-1)
    self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
    return logits, state
```
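If it helps to see the idea outside of RLlib, here is a minimal self-contained PyTorch sketch of the same pattern (the `PolicyHead` name, layer sizes, and dims are made up for illustration): a single learned log-std vector, shared across all states, broadcast over the batch and concatenated onto the predicted means.

```python
import torch
import torch.nn as nn

class PolicyHead(nn.Module):
    """Gaussian policy head with state-independent, learned log stds."""

    def __init__(self, obs_dim, act_dim, log_std_init=0.0):
        super().__init__()
        self.mean_net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, act_dim)
        )
        # One log-std per action dim, trained alongside the network weights.
        self.log_stds = nn.Parameter(torch.ones(act_dim) * log_std_init)

    def forward(self, obs):
        means = self.mean_net(obs)
        # Same log stds for every row in the batch.
        log_stds = self.log_stds.expand_as(means)
        return torch.cat((means, log_stds), dim=-1)

head = PolicyHead(obs_dim=4, act_dim=2)
out = head(torch.randn(8, 4))
print(out.shape)  # torch.Size([8, 4]): 2 means + 2 log stds per row
```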
Otherwise you can do the clamped version:
```python
# In __init__():
self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0] * 2, model_config, name + "_actor")

def forward(self, input_dict, state, seq_lens):
    # Get the model output and split it into means and log stds.
    logits, _ = self.actor_logits(input_dict, state, seq_lens)
    means, log_stds = torch.chunk(logits, 2, -1)
    log_stds = torch.clamp(log_stds, -1, 1)
    logits = torch.cat((means, log_stds), dim=-1)
    self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
    return logits, state
```
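One reason the [-1, 1] clamp range works well: it pins the actual standard deviation to [e⁻¹, e¹] ≈ [0.37, 2.72], so the policy can neither collapse to a near-deterministic point nor blow up. A quick standalone check (the example values are made up):

```python
import torch

log_stds = torch.tensor([-5.0, -1.2, 0.0, 3.0])
clamped = torch.clamp(log_stds, -1.0, 1.0)
stds = torch.exp(clamped)
print(clamped)  # tensor([-1.0000, -1.0000, 0.0000, 1.0000])
print(stds)     # stds stay inside [exp(-1) ~ 0.368, exp(1) ~ 2.718]
```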
Either of these fixed my issue and I no longer run into it. From lots of logging, it appears the logp_ratio goes to NaN because the std dev gets extremely small.
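That diagnosis is easy to reproduce outside of RLlib: once the std collapses far enough, the Gaussian log-probs hit inf/NaN and the importance ratio follows. A minimal sketch with `torch.distributions.Normal` in float32 (the collapsed std value is made up for illustration):

```python
import torch
from torch.distributions import Normal

action = torch.tensor(0.5)
logp_old = Normal(0.0, 1.0).log_prob(action)  # healthy old policy

# New policy whose std has collapsed toward zero:
collapsed = Normal(0.0, 1e-40)

# Off-mean actions: dividing by the underflowed variance makes the
# log-prob -inf, so the importance ratio underflows to 0.
print(torch.exp(collapsed.log_prob(action) - logp_old))             # tensor(0.)

# At the mean: scale**2 underflows to 0 in float32, the log-prob is
# 0/0 = NaN, and the ratio (hence the surrogate loss) goes NaN.
print(torch.exp(collapsed.log_prob(torch.tensor(0.0)) - logp_old))  # tensor(nan)
```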