Hey @inigo_Gastesi, this is a problem I have run into as well. You can see it here: PPO nan actor logits. That said, you can either add the log standard deviations as parameters of the model or clamp them. I'll put samples below (obviously change them to your liking; these are just small snippets).
Quick edit: this is only the most likely culprit. I have also seen similar symptoms caused by reward function issues, incorrect signs in the losses, etc. But since your action distribution is going NaN, the log stds are my first hunch.
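If you want to confirm that it really is the distribution inputs going NaN before restructuring anything, here is a minimal sketch of a finite-check helper you could call on the logits inside `forward` (the helper name `assert_finite` is my own, not an RLlib API):

```python
import torch

def assert_finite(name: str, tensor: torch.Tensor) -> None:
    # Fail fast with a descriptive error instead of letting NaN/inf
    # logits propagate into the action distribution.
    if not torch.isfinite(tensor).all():
        raise ValueError(f"{name} contains NaN/inf values: {tensor}")
```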
Parameterized:
```python
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC


class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        # The actor only outputs the means; the log stds are a free-floating,
        # state-independent parameter vector below.
        self.actor_means = TorchFC(obs_space, action_space, action_space.shape[0], model_config, name + "_actor")
        self.log_std_init = model_config["custom_model_config"].get("log_std_init", 0)
        self.log_stds = nn.Parameter(torch.ones(action_space.shape[0]) * self.log_std_init, requires_grad=True)

    def forward(self, input_dict, state, seq_lens):
        # Get the action means from the actor network
        means, _ = self.actor_means(input_dict, state, seq_lens)
        # Broadcast the learned log stds across the batch
        log_stds = self.log_stds.expand_as(means)
        logits = torch.cat((means, log_stds), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state

    def value_function(self):
        # Required by TorchModelV2; flatten the critic output to shape [B]
        return torch.reshape(self.value, [-1])
```
or clamp them like this:
```python
import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC


class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        # The actor outputs means and log stds concatenated, hence 2x the action dim
        self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0] * 2, model_config, name + "_actor")

    def forward(self, input_dict, state, seq_lens):
        logits_unclamped, _ = self.actor_logits(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits_unclamped, 2, -1)
        # Keep log stds in a safe range so the Gaussian never collapses or explodes
        clamped_log_stds = torch.clamp(log_stds, -1, 1)
        logits = torch.cat((means, clamped_log_stds), -1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state

    def value_function(self):
        # Required by TorchModelV2; flatten the critic output to shape [B]
        return torch.reshape(self.value, [-1])
```
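For completeness, here is a minimal sketch of how either model would be wired up. The model name `"simple_model"` and the `Pendulum-v1` env are just placeholders, and this assumes an older Ray release where `PPOTrainer` lives under `ray.rllib.agents.ppo`:

```python
import ray
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.models import ModelCatalog

ray.init()

# Register the custom model under a name the config can refer to
ModelCatalog.register_custom_model("simple_model", SimpleCustomTorchModel)

trainer = PPOTrainer(
    env="Pendulum-v1",
    config={
        "framework": "torch",
        "model": {
            "custom_model": "simple_model",
            # Only read by the parameterized variant
            "custom_model_config": {"log_std_init": 0.0},
        },
    },
)
trainer.train()
```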
I would try the parameterized version first, and if you are unhappy with the performance (e.g., due to a lack of exploration), then try the clamped version.
All the best,
Tyler