PPO Training Error: NaN Values in Gradients and Near-Zero Loss

Hey @inigo_Gastesi, this is a problem that I have run into as well. You can see it here: PPO nan actor logits. With that said, you can either add the log std deviations as learnable parameters of the model or clamp them. I have samples below (obviously change them to your liking; these are just small snippets).

quick edit: this is only one potential cause. I have also encountered other problems such as reward function issues, incorrect signs on the losses, etc. However, since your action distribution is going NaN, the log stds are my first hunch.
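If you want to confirm that hunch before changing the model, a quick check is an assertion on the logits right before forward() returns. This is a minimal sketch in plain torch; check_finite is just a hypothetical helper name, not anything from RLlib:

import torch

def check_finite(logits, tag="actor logits"):
    # Fail fast with a readable error instead of letting NaNs/Infs propagate
    # into the action distribution and the PPO loss.
    if not torch.isfinite(logits).all():
        raise ValueError(f"{tag} contain NaN/Inf values")
    return logits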

Parameterized:

import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC


class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Separate critic and actor networks; the actor only outputs the means.
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_means = TorchFC(obs_space, action_space, action_space.shape[0], model_config,
                                   name + "_actor")
        # The log std deviations are free, state-independent parameters instead of network outputs.
        self.log_std_init = model_config["custom_model_config"].get("log_std_init", 0.0)
        self.log_stds = nn.Parameter(torch.ones(action_space.shape[0]) * self.log_std_init,
                                     requires_grad=True)

    def forward(self, input_dict, state, seq_lens):
        # Means come from the actor network; log stds come from the learned parameter.
        means, _ = self.actor_means(input_dict, state, seq_lens)
        log_stds = self.log_stds.expand_as(means)
        logits = torch.cat((means, log_stds), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)

        return logits, state

    def value_function(self):
        # RLlib expects a 1-D tensor of per-sample value estimates.
        return torch.reshape(self.value, [-1])

or clamp them like this:

class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # Same imports as above; here the actor outputs both the means and the log stds.
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0] * 2, model_config,
                                    name + "_actor")

    def forward(self, input_dict, state, seq_lens):
        logits_unclamped, _ = self.actor_logits(input_dict, state, seq_lens)
        # Split the output into means and log stds, then clamp the log stds to a safe range.
        means, log_stds = torch.chunk(logits_unclamped, 2, dim=-1)
        clamped_log_stds = torch.clamp(log_stds, -1.0, 1.0)
        logits = torch.cat((means, clamped_log_stds), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)

        return logits, state

    def value_function(self):
        return torch.reshape(self.value, [-1])

I would try the parameterized version first, and if you are unhappy with the performance (e.g., due to lack of exploration if the learned log stds collapse), then try the clamped version; clamping to [-1, 1] keeps the std between roughly exp(-1) ≈ 0.37 and exp(1) ≈ 2.72.
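In case it helps, this is roughly how I wire one of these models into a PPO config. Treat it as a sketch: the model name "simple_custom_torch" and the env name "MyEnv-v0" are placeholders for your own setup.

from ray.rllib.models import ModelCatalog

# Register the custom model under a placeholder name.
ModelCatalog.register_custom_model("simple_custom_torch", SimpleCustomTorchModel)

config = {
    "framework": "torch",
    "env": "MyEnv-v0",  # placeholder env name
    "model": {
        "custom_model": "simple_custom_torch",
        # Only used by the parameterized version; log_std_init = 0.0 means std = 1.0 at the start.
        "custom_model_config": {"log_std_init": 0.0},
    },
}

Then pass the config to your PPO trainer as usual for your Ray version.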

All the best,

Tyler