PPO Training Error: NaN Values in Gradients and Near-Zero Loss

Hi everyone,

I’m trying to train a PPO model using RLlib, but I keep encountering the following error, which I assume is due to the gradients “dying”:

File "/opt/conda/lib/python3.9/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 85, in loss
    curr_action_dist = dist_class(logits, model)
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/models/torch/torch_action_dist.py", line 512, in __init__
    self.flat_child_distributions = tree.map_structure(
  File "/opt/conda/lib/python3.9/site-packages/tree/__init__.py", line 435, in map_structure
    [func(*args) for args in zip(*map(flatten, structures))])
  File "/opt/conda/lib/python3.9/site-packages/tree/__init__.py", line 435, in <listcomp>
    [func(*args) for args in zip(*map(flatten, structures))])
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/models/torch/torch_action_dist.py", line 513, in <lambda>
    lambda dist, input_: dist(input_, model),
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/models/torch/torch_action_dist.py", line 250, in __init__
    self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))
  File "/opt/conda/lib/python3.9/site-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "/opt/conda/lib/python3.9/site-packages/torch/distributions/distribution.py", line 68, in __init__
    raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (128, 2)) of distribution Normal(loc: torch.Size([128, 2]), scale: torch.Size([128, 2])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan],
        [nan, nan],
        ...
        [nan, nan]], grad_fn=<SplitBackward0>)

Additionally, I’ve noticed that the reward is very negative, but the loss quickly approaches zero, around 1e-5. The error occurs precisely when the policy loss reaches this point. Interestingly, switching the algorithm to APPO resolves the issue.

Can anyone explain what’s happening here? Why does this error occur with PPO but not with APPO?

Thanks in advance for your help!

Hey @inigo_Gastesi, this is a problem that I have run into as well. You can see it here: PPO nan actor logits. That said, you can either add the log std deviations as parameters of the model or clamp them. I'll include a sample below (obviously change it to your liking; this is just a small snippet).

Quick edit: this could be the root of the problem. I have also encountered other causes, such as reward function issues, incorrect signs in the losses, etc. However, since the action distribution is going NaN, this is my first hunch.
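
A generic way to narrow this down (just a sketch, not RLlib-specific) is to confirm that nothing upstream is already producing NaN/inf values, and to let autograd point at the op that first produces NaN gradients:

import numpy as np
import torch

# Ask autograd to report the forward op whose backward produced NaN.
torch.autograd.set_detect_anomaly(True)

def assert_finite(name, x):
    # Raise early if an array contains NaN or inf.
    arr = np.asarray(x, dtype=np.float32)
    if not np.all(np.isfinite(arr)):
        raise ValueError(f"{name} contains NaN/inf: {arr}")

# e.g. inside your environment's step():
#   assert_finite("obs", obs)
#   assert_finite("reward", reward)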

Parameterized:

import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC


class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_means = TorchFC(obs_space, action_space, action_space.shape[0], model_config,
                                   name + "_actor")
        # Learn the log stds as free (state-independent) parameters.
        self.log_std_init = model_config["custom_model_config"].get("log_std_init", 0.0)
        self.log_stds = nn.Parameter(torch.ones(action_space.shape[0]) * self.log_std_init,
                                     requires_grad=True)

    def forward(self, input_dict, state, seq_lens):
        # Means come from the actor network; log stds are the learned parameters.
        means, _ = self.actor_means(input_dict, state, seq_lens)
        log_stds = self.log_stds.expand_as(means)
        logits = torch.cat((means, log_stds), dim=-1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)

        return logits, state

or clamp them like this:

# Same imports as in the parameterized version above.
class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        # The actor now outputs both means and log stds (2 * action dim).
        self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0] * 2, model_config,
                                    name + "_actor")

    def forward(self, input_dict, state, seq_lens):
        logits_unclamped, _ = self.actor_logits(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits_unclamped, 2, -1)
        # Clamp the log stds so exp(log_std) stays in a safe range.
        clamped_log_stds = torch.clamp(log_stds, -1, 1)
        logits = torch.cat((means, clamped_log_stds), -1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)

        return logits, state

I would try the parameterized first and if you are unhappy with the performance (due to lack of exploration) then try the clamped version.

All the best,

Tyler

The clamp range should instead be the one below. Sorry for the mix-up: I was clamping my means from -1 to 1. Clamping the log stds to [-10, 0] keeps the standard deviation between exp(-10) and 1.

clamped_log_stds = torch.clamp(log_stds, -10, 0)

Hi @tlaurie99, thank you very much for your input. I couldn’t read it earlier because I was on vacation. If it happens again, I’ll try what you suggested. Thanks again for everything!

I’m trying to implement what you’ve suggested in RLlib. I understand that a custom model needs to be created, but I’m not entirely sure how to do it. You use TorchModelV2 to create the custom model, but this requires more components to make it work, right? It’s worth mentioning that I’m using the Tuner to train the model, so I’m passing trainable: PPO and then, in the param_space, specifying that it should use custom_model: "SimpleCustomTorchModel". Of course, I registered the model beforehand. However, I’m currently getting the following error:

  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 532, in __init__
    self._update_policy_map(policy_dict=self.policy_dict)
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1737, in _update_policy_map
    self._build_policy_map(
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1848, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 141, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
    self._initialize_loss_from_dummy_batch()
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 1396, in _initialize_loss_from_dummy_batch
    actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/policy/torch_policy_v2.py", line 560, in compute_actions_from_input_dict
    return self._compute_action_helper(
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1293, in _compute_action_helper
    extra_fetches = self.extra_action_out(
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/policy/torch_mixins.py", line 185, in extra_action_out
    SampleBatch.VF_PREDS: model.value_function(),
  File "/opt/conda/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 150, in value_function
    raise NotImplementedError

I imagine this is because more things need to be defined, but right now, I’m not sure how to solve it.
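
For context, my registration and Tuner setup look roughly like this (simplified; the environment name and custom_model_config values here are just placeholders):

from ray import tune
from ray.rllib.models import ModelCatalog

# SimpleCustomTorchModel is the class from the snippets above.
# Register it under the name referenced in the config.
ModelCatalog.register_custom_model("SimpleCustomTorchModel", SimpleCustomTorchModel)

tuner = tune.Tuner(
    "PPO",  # the trainable
    param_space={
        "env": "Pendulum-v1",  # placeholder, my real env goes here
        "framework": "torch",
        "model": {
            "custom_model": "SimpleCustomTorchModel",
            # Only needed for the parameterized (learned log std) version.
            "custom_model_config": {"log_std_init": 0.0},
        },
    },
)
results = tuner.fit()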

@inigo_Gastesi: In fact, the ModelV2 class contains a lot of abstract methods, which are to be implemented in sub-classes. The hierarchy is as follows:

ModelV2
|
TorchModelV2
|
SimpleCustomTorchModel (or whatever class name you gave for your one)

The value_function() method must be implemented in your class. In the example files, it most often simply contains a call to the value_function() of the Torch FullyConnectedNetwork, or whatever neural network class you are using.
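
For illustration, a minimal self-contained sketch that combines the clamped-log-std model from above with a value_function() implementation could look like this (here it simply returns the critic output cached in forward(); treat it as a starting point, not the exact model from the earlier replies):

import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC


class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)
        # Separate critic (1 output) and actor (means + log stds) networks.
        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0] * 2,
                                    model_config, name + "_actor")
        self._value = None

    def forward(self, input_dict, state, seq_lens):
        logits_unclamped, _ = self.actor_logits(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits_unclamped, 2, -1)
        logits = torch.cat((means, torch.clamp(log_stds, -10, 0)), -1)
        # Cache the critic output so value_function() can return it.
        self._value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state

    def value_function(self):
        # RLlib calls this right after forward(); must return a tensor of shape [B].
        return torch.reshape(self._value, [-1])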

If you want a custom model within RLlib, then you should override the forward() and value_function() methods of TorchModelV2 (if that is the class you inherited from). You can always just call the inherited value_function() / forward() methods if you don’t want to implement custom solutions. An example can be seen here. Let me know if you still encounter these issues. I did not post the whole model in my previous reply and was mainly showing the clamping of the log stds.

All the best! Also, thank you @PhilippWillms for the additional help clarifying.