PPO nan in actor logits

Sorry for the late reply @hermmanhender, but yes: as @mannyv stated, you either have to add the log standard deviations as parameters of the model, or clamp them to a range such as -1 to 1. In my experience, parameterizing them leads to less exploration, and when I pit the two versions against each other in custom multi-agent scenarios, the clamped version wins more often. The parameterized version looks something like this:

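        # In __init__ of a custom TorchModelV2/nn.Module subclass. Assumes:
        #   import torch, torch.nn as nn
        #   from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_means = TorchFC(obs_space, action_space, action_space.shape[0], model_config,
                                   name + "_actor")
        # One learnable log std per action dimension, independent of the observation
        self.log_std_init = model_config["custom_model_config"].get("log_std_init", 0)
        self.log_stds = nn.Parameter(torch.ones(action_space.shape[0]) * self.log_std_init,
                                     requires_grad=True)

    def forward(self, input_dict, state, seq_lens):
        means, _ = self.actor_means(input_dict, state, seq_lens)
        # Broadcast the shared log stds to the batch shape of the means
        log_stds = self.log_stds.expand_as(means)
        logits = torch.cat((means, log_stds), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state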

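The parameterized version reads `log_std_init` from `custom_model_config`, so that value has to be passed in the model config. A minimal sketch, assuming the model class was registered with `ModelCatalog.register_custom_model` under the hypothetical name `"param_log_std_model"`; the value -0.5 is only an example. If you use RLlib's default fully connected network rather than a custom model, its `free_log_std` option also turns the log stds into state-independent learned parameters.

```python
# Hypothetical model name; assumes the custom model class was registered with
# ModelCatalog.register_custom_model("param_log_std_model", <your model class>).
config = {
    "model": {
        "custom_model": "param_log_std_model",
        # Read in __init__ via model_config["custom_model_config"].get("log_std_init", 0)
        "custom_model_config": {"log_std_init": -0.5},  # example value only
    },
}

# Alternative without a custom model: RLlib's default FC net learns
# state-independent log stds itself when free_log_std is enabled.
default_net_config = {"model": {"free_log_std": True}}
```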
Otherwise, you can use the clamped version:

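        # In __init__: one head outputs means and log stds (2 values per action dim);
        # self.critic_fcnet is defined as in the previous snippet
        self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0] * 2, model_config,
                                    name + "_actor")

    def forward(self, input_dict, state, seq_lens):
        # Get the model output and split it into means and log stds
        logits, _ = self.actor_logits(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits, 2, -1)
        # Bound log std to [-1, 1], i.e. std in [exp(-1), exp(1)] ≈ [0.37, 2.72]
        log_stds = torch.clamp(log_stds, -1, 1)
        logits = torch.cat((means, log_stds), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state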

Either of these seems to have fixed the issue for me, and I no longer run into it. From a lot of logging, it looks like the logp_ratio becomes NaN because the standard deviation collapses to an extremely small value.
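A minimal NumPy sketch of one way that collapse can surface as NaN, assuming a diagonal Gaussian policy evaluated in float32 (the helpers `diag_gaussian_logp` and `ratio_for` are hypothetical, not RLlib functions). The old and new policies are identical here, so a healthy ratio is exactly 1. Once log std is so negative that exp(log_std) underflows to 0 in float32, the sampled action equals the mean, (x - mean) / std becomes 0/0, and both the log-prob and the ratio turn into NaN. Clamping log std to [-1, 1] keeps everything finite.

```python
import numpy as np

LOG_2PI = float(np.log(2.0 * np.pi))

def diag_gaussian_logp(x, mean, log_std):
    """Log-density of a diagonal Gaussian, summed over action dims."""
    std = np.exp(log_std)
    z = (x - mean) / std
    return np.sum(-0.5 * z**2 - log_std - 0.5 * LOG_2PI, axis=-1)

def ratio_for(log_std_value, clamp=None, seed=0):
    """Sample one 2-D action and return exp(logp_new - logp_old) in float32.

    Old and new policy share the same parameters, so a healthy ratio is 1.
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros(2, dtype=np.float32)
    log_std = np.full(2, log_std_value, dtype=np.float32)
    if clamp is not None:
        # Same idea as torch.clamp(log_stds, -1, 1) in the model above
        log_std = np.clip(log_std, -clamp, clamp)
    noise = rng.standard_normal(2).astype(np.float32)
    with np.errstate(divide="ignore", invalid="ignore"):
        # exp(-120) underflows to 0 in float32, so x == mean exactly
        x = mean + np.exp(log_std) * noise
        logp_old = diag_gaussian_logp(x, mean, log_std)
        logp_new = diag_gaussian_logp(x, mean, log_std)
        return np.exp(logp_new - logp_old)

print(ratio_for(-120.0))             # collapsed std
print(ratio_for(-120.0, clamp=1.0))  # same input, log std clamped to [-1, 1]
```

The first call prints `nan` even though the policy did not change at all; the clamped call prints `1.0`.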
