PPO Training Error: NaN Values in Gradients and Near-Zero Loss

Hey @inigo_Gastesi, this is a problem that I have run into as well. You can see it here: PPO nan actor logits. To work around it, you can either make the log standard deviations free parameters of the model or clamp them. A sample of each is below (this is just a small snippet, so adapt it to your liking).

Quick edit: this may be the cause, but it is not the only possibility. I have also run into NaNs caused by reward function issues, incorrect signs on the losses, etc. However, since the action distribution is what is going NaN, this is my first hunch.
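To make the hunch a bit more concrete, here is a minimal pure-Python sketch of the arithmetic, not RLlib's actual code. The `gaussian_log_prob` helper and the specific numbers are my own for illustration. It shows how a log std that drifts to an extreme value turns a Gaussian log-probability into -inf, and how subtracting two such terms (as happens inside a probability ratio) gives NaN.

```python
import math


def gaussian_log_prob(x, mean, log_std):
    """Log density of a 1-D Gaussian parameterized by its mean and log std."""
    std = math.exp(log_std)
    z = (x - mean) / std
    return -0.5 * z * z - log_std - 0.5 * math.log(2.0 * math.pi)


# A moderate log std gives an ordinary finite log-probability.
moderate = gaussian_log_prob(x=1.0, mean=0.0, log_std=0.0)

# A very negative log std makes std tiny, z * z overflows to inf,
# and the log-probability collapses to -inf.
collapsed = gaussian_log_prob(x=1.0, mean=0.0, log_std=-400.0)

# Subtracting two -inf log-probabilities (new policy minus old policy)
# is inf - inf, which is NaN.
ratio_term = collapsed - collapsed

print(moderate, collapsed, ratio_term)
```

This uses 64-bit Python floats, so the log std has to be very extreme before anything breaks. The 32-bit floats normally used in training overflow at far less extreme values, which is why bounding or decoupling the log std tends to help.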


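import torch
from torch import nn
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        # nn.Module must be initialized before any submodule or parameter is assigned.
        nn.Module.__init__(self)

        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")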
        self.actor_means = TorchFC(obs_space, action_space, action_space.shape[0], model_config, name + 
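        # State-independent log std: one learnable value per action dimension.
        self.log_std_init = model_config['custom_model_config'].get('log_std_init', 0)
        self.log_stds = nn.Parameter(torch.ones(action_space.shape[0]) * self.log_std_init, requires_grad=True)

    def forward(self, input_dict, state, seq_lens):
        # Only the means depend on the observation.
        means, _ = self.actor_means(input_dict, state, seq_lens)
        log_stds = self.log_stds.expand_as(means)
        logits = torch.cat((means, log_stds), dim=-1)
        # Cache the critic output for value_function().
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state

    def value_function(self):
        # RLlib expects a flat [batch] tensor of value estimates.
        return self.value.reshape(-1)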

Or clamp the log stds like this:

class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        # nn.Module must be initialized before any submodule is assigned.
        nn.Module.__init__(self)

        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config, name + 

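    def forward(self, input_dict, state, seq_lens):
        # The actor outputs means and log stds concatenated along the last dimension.
        logits_unclamped, _ = self.actor_logits(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits_unclamped, 2, -1)
        # Bound the log stds so the action distribution cannot collapse or blow up.
        clamped_log_stds = torch.clamp(log_stds, -1, 1)
        logits = torch.cat((means, clamped_log_stds), -1)
        # Cache the critic output for value_function().
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state

    def value_function(self):
        # RLlib expects a flat [batch] tensor of value estimates.
        return self.value.reshape(-1)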
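For a feel of what the [-1, 1] bounds mean in terms of the actual standard deviation, here is a small pure-Python sketch. The scalar `clamp` helper and the sample values are hypothetical stand-ins for `torch.clamp` and the network outputs.

```python
import math

# Same bounds as the torch.clamp call above.
LOG_STD_MIN, LOG_STD_MAX = -1.0, 1.0


def clamp(value, low, high):
    """Scalar stand-in for torch.clamp on a single float."""
    return max(low, min(high, value))


# Made-up raw log stds, including one extreme value.
raw_log_stds = [-400.0, -1.0, 0.0, 3.0]
clamped = [clamp(v, LOG_STD_MIN, LOG_STD_MAX) for v in raw_log_stds]
stds = [math.exp(v) for v in clamped]

for raw, std in zip(raw_log_stds, stds):
    print(f"raw log std {raw:8.1f} -> std {std:.3f}")
```

So the standard deviation always stays between roughly 0.37 and 2.72. Keep in mind that `torch.clamp` passes zero gradient for values outside the range, so a log std that sits far outside the bounds gets no learning signal through this path. If those bounds are too tight or too loose for your action scale, widen or narrow them.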

I would try the parameterized version first, and if you are unhappy with the performance (due to lack of exploration), then try the clamped version.

All the best,
