PPO NaN in actor logits

How severely does this issue affect your experience of using Ray?
-High: the custom models and custom policies I am testing for work do not run properly.

I am using custom models and custom policies with the PyFlyt environment, specifically MAFixedwingDogfightEnv, which I wrap as a custom MultiAgentEnv with a reward wrapper, following the PyFlyt example for RLlib. I am using two models: 1) a mixture-of-Gaussians critic with a TorchFC actor, and 2) a critic and actor that are both TorchFC. I will reference 2) for simplicity.

What I have done:
-Looked through all the losses and saw that the total loss spikes beyond 1e+6, driven by the KL term also reaching 1e+6 (the vf and policy losses were both on the order of 0.01).
-Set kl_coeff = 0.
-Revamped the policy loss within a custom policy to be

total_loss = -surrogate_loss + self.config["vf_loss_coeff"] * mean_vf_loss - self.entropy_coeff * mean_entropy

versus RLlib's current implementation, which takes the mean after adding all the components together. Also tried a different way of calculating the logp_ratio (like SB3).
-Normalized advantages using a custom callback (see the sketch after this list).
-Added a reward wrapper around the environment that scales the reward down by a factor of 10 or 100 (tried both).
-Altered hyperparameters: clip_param, vf_clip_param, vf_loss_coeff, grad_clip, kl_coeff, and lr lowered to 0.0005.
-I also get this error with base policies like ppo_torch_policy and a simple model.
-Added assertions, which pointed me toward the actor network's logits becoming NaN first; inspecting the weights after training shows NaNs in both the actor and the critic.
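
For context, a minimal sketch of the kind of advantage-normalizing callback referenced above (the actual NormalizeAdvantagesCallback used in the config below may differ in detail):

from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.evaluation.postprocessing import Postprocessing


class NormalizeAdvantagesCallback(DefaultCallbacks):
    def on_postprocess_trajectory(
        self, *, worker, episode, agent_id, policy_id, policies,
        postprocessed_batch, original_batches, **kwargs
    ):
        # Standardize advantages per postprocessed batch; the epsilon guards
        # against a near-zero std on very small batches.
        adv = postprocessed_batch[Postprocessing.ADVANTAGES]
        postprocessed_batch[Postprocessing.ADVANTAGES] = (adv - adv.mean()) / (adv.std() + 1e-8)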

Hopefully, someone can point out my mistake so I can get back to running simulations and models. If it is a small mistake that I overlooked, please let me know! I have been trying to track down this error for months now.

Thank you for your help!

The error trace is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py:1348, in TorchPolicyV2._multi_gpu_parallel_grad_calc.<locals>._worker(shard_idx, model, sample_batch, device)
   1344 with NullContextManager() if device.type == "cpu" else torch.cuda.device(  # noqa: E501
   1345     device
   1346 ):
   1347     loss_out = force_list(
-> 1348         self.loss(model, self.dist_class, sample_batch)
   1349     )
   1351     # Call Model's custom-loss with Policy loss outputs and
   1352     # train_batch.

File /workspace/pyflyt/policies/ppo_sb3_loss.py:50, in CustomLossPolicy.loss(self, model, dist_class, train_batch)
     46 @override(PPOTorchPolicy)
     47 def loss(self, model, dist_class: Type[ActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
     48     
     49     
---> 50     return self.custom_ppo_loss(model, dist_class, train_batch, self.config)

File /workspace/pyflyt/policies/ppo_sb3_loss.py:68, in CustomLossPolicy.custom_ppo_loss(self, model, dist_class, train_batch, config)
     66 logits, state = model(train_batch)
---> 68 curr_action_dist = dist_class(logits, model)
     70 # RNN case: Mask away 0-padded chunks at end of time axis.

File /usr/local/lib/python3.10/dist-packages/ray/rllib/models/torch/torch_action_dist.py:250, in TorchDiagGaussian.__init__(self, inputs, model, action_space)
    249 self.log_std = log_std
--> 250 self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))
    251 # Remember to squeeze action samples in case action space is Box(shape)

File /usr/local/lib/python3.10/dist-packages/torch/distributions/normal.py:56, in Normal.__init__(self, loc, scale, validate_args)
     55     batch_shape = self.loc.size()
---> 56 super().__init__(batch_shape, validate_args=validate_args)

File /usr/local/lib/python3.10/dist-packages/torch/distributions/distribution.py:68, in Distribution.__init__(self, batch_shape, event_shape, validate_args)
     67         if not valid.all():
---> 68             raise ValueError(
     69                 f"Expected parameter {param} "
     70                 f"({type(value).__name__} of shape {tuple(value.shape)}) "
     71                 f"of distribution {repr(self)} "
     72                 f"to satisfy the constraint {repr(constraint)}, "
     73                 f"but found invalid values:\n{value}"
     74             )
     75 super().__init__()

ValueError: Expected parameter loc (Tensor of shape (100, 4)) of distribution Normal(loc: torch.Size([100, 4]), scale: torch.Size([100, 4])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, nan],
        [nan, nan, nan, nan],

        [nan, nan, nan, nan]], device='cuda:0', grad_fn=<SplitBackward0>)

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
File <timed exec>:64

File /usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py:331, in Trainable.train(self)
    329 except Exception as e:
    330     skipped = skip_exceptions(e)
--> 331     raise skipped from exception_cause(skipped)
    333 assert isinstance(result, dict), "step() needs to return a dict."
    335 # We do not modify internal state nor update this result if duplicate.

File /usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py:328, in Trainable.train(self)
    326 start = time.time()
    327 try:
--> 328     result = self.step()
    329 except Exception as e:
    330     skipped = skip_exceptions(e)

File /usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/algorithm.py:870, in Algorithm.step(self)
    860     (
    861         train_results,
    862         eval_results,
    863         train_iter_ctx,
    864     ) = self._run_one_training_iteration_and_evaluation_in_parallel()
    866 # - No evaluation necessary, just run the next training iteration.
    867 # - We have to evaluate in this training iteration, but no parallelism ->
    868 #   evaluate after the training iteration is entirely done.
    869 else:
--> 870     train_results, train_iter_ctx = self._run_one_training_iteration()
    872 # Sequential: Train (already done above), then evaluate.
    873 if evaluate_this_iter and not self.config.evaluation_parallel_to_training:

File /usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/algorithm.py:3160, in Algorithm._run_one_training_iteration(self)
   3156 # Try to train one step.
   3157 with self._timers[TRAINING_STEP_TIMER]:
   3158     # TODO (sven): Should we reduce the different
   3159     #  `training_step_results` over time with MetricsLogger.
-> 3160     training_step_results = self.training_step()
   3162 if training_step_results:
   3163     results = training_step_results

File /usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/ppo/ppo.py:428, in PPO.training_step(self)
    424     return self._training_step_new_api_stack()
    425 # Old and hybrid API stacks (Policy, RolloutWorker, Connector, maybe RLModule,
    426 # maybe Learner).
    427 else:
--> 428     return self._training_step_old_and_hybrid_api_stacks()

File /usr/local/lib/python3.10/dist-packages/ray/rllib/algorithms/ppo/ppo.py:594, in PPO._training_step_old_and_hybrid_api_stacks(self)
    587     train_results = self.learner_group.update_from_batch(
    588         batch=train_batch,
    589         minibatch_size=mini_batch_size_per_learner,
    590         num_iters=self.config.num_sgd_iter,
    591     )
    593 elif self.config.simple_optimizer:
--> 594     train_results = train_one_step(self, train_batch)
    595 else:
    596     train_results = multi_gpu_train_one_step(self, train_batch)

File /usr/local/lib/python3.10/dist-packages/ray/rllib/execution/train_ops.py:56, in train_one_step(algorithm, train_batch, policies_to_train)
     52 with learn_timer:
     53     # Subsample minibatches (size=`sgd_minibatch_size`) from the
     54     # train batch and loop through train batch `num_sgd_iter` times.
     55     if num_sgd_iter > 1 or sgd_minibatch_size > 0:
---> 56         info = do_minibatch_sgd(
     57             train_batch,
     58             {
     59                 pid: local_worker.get_policy(pid)
     60                 for pid in policies_to_train
     61                 or local_worker.get_policies_to_train(train_batch)
     62             },
     63             local_worker,
     64             num_sgd_iter,
     65             sgd_minibatch_size,
     66             [],
     67         )
     68     # Single update step using train batch.
     69     else:
     70         info = local_worker.learn_on_batch(train_batch)

File /usr/local/lib/python3.10/dist-packages/ray/rllib/utils/sgd.py:129, in do_minibatch_sgd(samples, policies, local_worker, num_sgd_iter, sgd_minibatch_size, standardize_fields)
    126     for i in range(num_sgd_iter):
    127         for minibatch in minibatches(batch, sgd_minibatch_size):
    128             results = (
--> 129                 local_worker.learn_on_batch(
    130                     MultiAgentBatch({policy_id: minibatch}, minibatch.count)
    131                 )
    132             )[policy_id]
    133             learner_info_builder.add_learn_on_batch_results(results, policy_id)
    135 learner_info = learner_info_builder.finalize()

File /usr/local/lib/python3.10/dist-packages/ray/rllib/evaluation/rollout_worker.py:806, in RolloutWorker.learn_on_batch(self, samples)
    804             to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch)
    805         else:
--> 806             info_out[pid] = policy.learn_on_batch(batch)
    808     info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()})
    809 else:

File /usr/local/lib/python3.10/dist-packages/ray/rllib/utils/threading.py:24, in with_lock.<locals>.wrapper(self, *a, **k)
     22 try:
     23     with self._lock:
---> 24         return func(self, *a, **k)
     25 except AttributeError as e:
     26     if "has no attribute '_lock'" in e.args[0]:

File /usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py:715, in TorchPolicyV2.learn_on_batch(self, postprocessed_batch)
    709 self.callbacks.on_learn_on_batch(
    710     policy=self, train_batch=postprocessed_batch, result=learn_stats
    711 )
    713 # Compute gradients (will calculate all losses and `backward()`
    714 # them to get the grads).
--> 715 grads, fetches = self.compute_gradients(postprocessed_batch)
    717 # Step the optimizers.
    718 self.apply_gradients(_directStepOptimizerSingleton)

File /usr/local/lib/python3.10/dist-packages/ray/rllib/utils/threading.py:24, in with_lock.<locals>.wrapper(self, *a, **k)
     22 try:
     23     with self._lock:
---> 24         return func(self, *a, **k)
     25 except AttributeError as e:
     26     if "has no attribute '_lock'" in e.args[0]:

File /usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py:933, in TorchPolicyV2.compute_gradients(self, postprocessed_batch)
    930 self._lazy_tensor_dict(postprocessed_batch, device=self.devices[0])
    932 # Do the (maybe parallelized) gradient calculation step.
--> 933 tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])
    935 all_grads, grad_info = tower_outputs[0]
    937 grad_info["allreduce_latency"] /= len(self._optimizers)

File /usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py:1433, in TorchPolicyV2._multi_gpu_parallel_grad_calc(self, sample_batches)
   1431         last_result = results[len(results) - 1]
   1432         if isinstance(last_result[0], ValueError):
-> 1433             raise last_result[0] from last_result[1]
   1434 # Multi device (GPU) case: Parallelize via threads.
   1435 else:
   1436     threads = [
   1437         threading.Thread(
   1438             target=_worker, args=(shard_idx, model, sample_batch, device)
   (...)
   1442         )
   1443     ]

ValueError: Expected parameter loc (Tensor of shape (100, 4)) of distribution Normal(loc: torch.Size([100, 4]), scale: torch.Size([100, 4])) to satisfy the constraint Real(), but found invalid values:
tensor([[nan, nan, nan, nan],

        [nan, nan, nan, nan]], device='cuda:0', grad_fn=<SplitBackward0>)
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/policy/torch_policy_v2.py", line 1348, in _worker
    self.loss(model, self.dist_class, sample_batch)
  File "/workspace/pyflyt/policies/ppo_sb3_loss.py", line 50, in loss
    return self.custom_ppo_loss(model, dist_class, train_batch, self.config)
  File "/workspace/pyflyt/policies/ppo_sb3_loss.py", line 68, in custom_ppo_loss
    curr_action_dist = dist_class(logits, model)
  File "/usr/local/lib/python3.10/dist-packages/ray/rllib/models/torch/torch_action_dist.py", line 250, in __init__
    self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))
  File "/usr/local/lib/python3.10/dist-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/distributions/distribution.py", line 68, in __init__
    raise ValueError(
ValueError: Expected parameter loc (Tensor of shape (100, 4)) of distribution Normal(loc: torch.Size([100, 4]), scale: torch.Size([100, 4])) to satisfy the constraint Real(), but found invalid values:

my training block is:

%%time

import pandas as pd
import ray
from ray.rllib.algorithms.ppo import PPOConfig
# env_creator, env_config, policy_mapping_fn, CustomLossPolicy, and
# NormalizeAdvantagesCallback are defined in earlier cells.

env_example = env_creator(env_config)
obs_space = env_example.observation_space
action_space = env_example.action_space

config = PPOConfig().training(
    gamma = 0.99,
    lambda_ = 0.95,
    # kl_coeff = 0.5,
    num_sgd_iter = 15,
    # lr_schedule = [[0, 0.0003], [15_000_000, 0.00025], [30_000_000, 0.0002], [50_000_000, 0.0001]],
    lr = 0.0001,
    vf_loss_coeff = 0.5,
    vf_clip_param = 1.0,
    clip_param = 0.2,
    grad_clip_by ='norm', 
    train_batch_size = 2_000, 
    sgd_minibatch_size = 100,
    grad_clip = 0.5,
    kl_coeff = 0.01,
    entropy_coeff = 0.0,
    model = {'custom_model': 'SimpleCustomTorchModel', #SimpleCustomTorchModel MOGTorchModel
           'vf_share_layers': False,
           'fcnet_hiddens': [256,256],
           'fcnet_activation': 'LeakyReLU',
           'custom_model_config': {
                'num_gaussians': 3,
                'num_layers': 2,
                # 'num_outputs': action_space.shape[0],
                # 'parquet_file_name': 'logs/critic_logging_sigma.parquet',
           }
            }
).environment(
    env = 'MAFixedwingDogfightEnv',
    env_config = env_config
).rollouts(
    num_rollout_workers = 10
).resources(
    num_gpus = 1
).callbacks(
    NormalizeAdvantagesCallback
).multi_agent(
    policies = {
        'policy_1': (CustomLossPolicy, obs_space, action_space, {}),
        'policy_2': (CustomLossPolicy, obs_space, action_space, {}),
    },
    policy_mapping_fn=policy_mapping_fn
)

# analysis = tune.run(
#     'PPO',
#     config=config.to_dict(),
#     stop={'training_iteration':300},
#     checkpoint_freq=10,
#     checkpoint_at_end=True,
#     # local_dir='./ray_results'
# )


algo = config.build()

num_iterations = 1500
results = []

for i in range(num_iterations):
    result = algo.train()
    if i % 10 == 0:
        # print(f"Iteration: {i}, Mean Reward: {result['env_runners']['episode_reward_mean']} episode length: {result['env_runners']['episode_len_mean']}")
        print(f"Iteration: {i}, Policy 1 Mean Reward: {result['env_runners']['policy_reward_mean']['policy_1']}\n"
              f"Iteration: {i}, Policy 2 Mean Reward: {result['env_runners']['policy_reward_mean']['policy_2']}\n"
              f"Iteration: {i}, episode length: {result['env_runners']['episode_len_mean']}")

    results.append([result['env_runners']['episode_reward_mean'], result['env_runners']['episode_len_mean']])

results_df = pd.DataFrame(results)

ray.shutdown()

and the model is:

import torch.nn as nn

from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override


class SimpleCustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_fcnet = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config, name + "_actor")
        self.log_step = 0
        
    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        # Get the model output
        logits, _ = self.actor_fcnet(input_dict, state, seq_lens)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        self.log_step += 1
        
        return logits, state
        
    @override(TorchModelV2)
    def value_function(self):
        return self.value.squeeze(-1)

# Register the custom model to make it available to Ray/RLlib
ModelCatalog.register_custom_model("SimpleCustomTorchModel", SimpleCustomTorchModel)
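
For reference, the assertions mentioned in the list above are essentially finiteness checks right where the logits are produced (a sketch, not the exact code I run):

import torch

def assert_finite(name, tensor, step):
    # Fail fast with the training step at which a tensor went non-finite,
    # instead of waiting for the Normal constructor to raise downstream.
    if not torch.isfinite(tensor).all():
        raise AssertionError(f"{name} non-finite at log_step={step}")

# called inside SimpleCustomTorchModel.forward, right after the actor call:
#     assert_finite("actor logits", logits, self.log_step)

torch.autograd.set_detect_anomaly(True) is also useful here: it reports the backward op that first produced a NaN, at a noticeable cost in speed.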

Hi @tlaurie99,

Welcome to the forum.

If I had to venture a guess as to where the NaNs originate, it would be here:

250 self.dist = torch.distributions.normal.Normal(mean, torch.exp(log_std))

In my experience, when using a continuous action space, the std logits that parameterize the action distribution can become very negative at some point during training. That leads to a std close to zero, which produces a NaN when the backward pass of the normal distribution divides by ~zero.
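
A tiny standalone illustration of that mechanism (a sketch, independent of RLlib; the exact gradient values will vary):

import torch

loc = torch.zeros(1, requires_grad=True)
log_std = torch.tensor([-40.0], requires_grad=True)  # has drifted very negative

dist = torch.distributions.Normal(loc, log_std.exp())
nll = -dist.log_prob(torch.tensor([0.1])).sum()
nll.backward()

# With the scale this close to zero, the gradients are astronomically large or
# non-finite; one optimizer step with them is enough to push the actor weights
# (and therefore the next batch of loc logits) to inf/NaN, which is exactly the
# ValueError in the traceback above.
print(loc.grad, log_std.grad)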

RLlib is unique among the popular frameworks in that it uses the policy network to generate the log_std values.

If you look at CleanRL's or SB3's implementations, you will see that they register the log_std as a parameter of the model, so it can be learned but is not part of the nn layers.

Since you are already using a custom model you might try implementing this alternative to see if it helps.
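
Roughly the pattern those implementations use (a sketch, not their exact code; the names here are illustrative): the network only outputs the action means, and the log-std lives in a separate learnable parameter that gets broadcast to the batch.

import torch
import torch.nn as nn

class DiagGaussianActor(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.mean_net = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(), nn.Linear(64, act_dim)
        )
        # Learned, state-independent log-std: a plain parameter,
        # not the output of an nn layer.
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, obs):
        mean = self.mean_net(obs)
        std = self.log_std.expand_as(mean).exp()
        return torch.distributions.Normal(mean, std)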


Hey @mannyv,

Thanks for the response. I have been going back and forth between SB3 and CleanRL to see the differences as well. Today, after adding logging, we looked through the errors and noticed that the NaNs seem to originate in the actor logp_ratio calculation, which comes from exactly what you are pointing at. I compared the implementations and saw exactly the difference you describe, so I will go ahead and give it a try. From your experience, is it worth clamping the log_stds, or should I simply turn them into a parameter of the model? I have also noticed that in other continuous environments (like HalfCheetah) I get the same error after ~50M timesteps, so maybe letting the network predict the log_std isn't the way to go?

Thanks again!

Tyler


Hi @tlaurie99! Do you have any updates on this topic?

I'm trying to solve the same problem, but in my case with only a custom multi-agent environment, a built-in model, and the PPO algorithm.

Thanks in advance!

I had this problem for weeks and was able to solve it by tinkering with the number of GPUs assigned. I stepped through the code and found that if I assigned too many GPUs or too many learners, a NaN would be computed in various places (for example, when computing the variance of a tensor with only one element).

Hence, if I decreased the number of learners to 1, did not specify the num_gpus parameter in resources(), and set the batch sizes properly (so that each learner had enough samples), the training ended up working properly.

I have seen this particular error (about the action distribution loc being NaN) posted by many people, and I suspect each person may have a slightly different cause. However, since this fix worked for me, it is probably a good idea to try out these changes; a rough config sketch is below.
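
Roughly the kind of change described above (a sketch only; the exact method names depend on your Ray version):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    # A single learner, so no learner shard ends up with too few samples.
    # (On older Ray versions the equivalent setting lives under
    #  .resources(num_learner_workers=1) instead of .learners().)
    .learners(num_learners=1)
    # Batch size chosen so the learner always receives a full batch;
    # the number here is a placeholder.
    .training(train_batch_size=4000)
)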

Sorry for the late reply @hermmanhender, but yes, as @mannyv stated, you either have to add the log standard deviations as parameters of the model, or you can clamp them to something like [-1, 1]. I have found that parameterizing them allows for less exploration; when I fight the two variants against each other in custom multi-agent scenarios, the clamped version wins more often. The parameterized version looks something like this:

        self.critic_fcnet = TorchFC(obs_space, action_space, 1, model_config, name + "_critic")
        self.actor_means = TorchFC(obs_space, action_space, action_space.shape[0], model_config,
                                   name + "_actor")
        self.log_std_init = model_config['custom_model_config'].get('log_std_init', 0)
        # State-independent log-stds, learned as a plain parameter rather than
        # predicted by the network.
        self.log_stds = nn.Parameter(torch.ones(action_space.shape[0]) * self.log_std_init, requires_grad = True)

    def forward(self, input_dict, state, seq_lens):
        means, _ = self.actor_means(input_dict, state, seq_lens)
        log_stds = self.log_stds.expand_as(means)
        logits = torch.cat((means, log_stds), dim = -1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        return logits, state

Otherwise, you can do the clamped version:

        self.actor_logits = TorchFC(obs_space, action_space, action_space.shape[0]*2, model_config,
                                    name + "_actor")

    def forward(self, input_dict, state, seq_lens):
        # Get the model output and split it into means and log-stds.
        logits, _ = self.actor_logits(input_dict, state, seq_lens)
        means, log_stds = torch.chunk(logits, 2, -1)
        # Clamp the log-stds so the std can never collapse to ~zero.
        log_stds = torch.clamp(log_stds, -1, 1)
        logits = torch.cat((means, log_stds), dim = -1)
        self.value, _ = self.critic_fcnet(input_dict, state, seq_lens)
        self.log_step += 1
        return logits, state

Either of these seems to have fixed my issue and I no longer run into it. From lots of logging, it appears the logp_ratio goes to NaN because the std_dev becomes extremely small.


@tlaurie99 thanks for your post. @mannyv's guess is, as usual, excellent and points directly to the problem. I will not get into the details here, as you and @mannyv have already covered them.

I would like to point to two options in RLlib that implement the suggestions given in this thread.

  1. Training the log standard deviation as a simple nn.Parameter (in torch) was already implemented (I admit it was not well documented): setting {"free_log_std": True} in the model_config_dict will use the TorchFreeLogStdMLPHead, which optimizes the standard deviation in the form of a freely moving bias term.
  2. In response to this thread, we also added log-std clipping, in a similar form to what is shown above, to all algorithms with continuous actions. Setting {"log_std_clip_param": 1} in your model_config_dict clamps the log-stds of all actions to [-1, 1] (the default is 20; to avoid clipping entirely, use float("inf")). A minimal usage sketch is shown below.
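
A minimal sketch of how the two options plug into the model config (assuming a recent RLlib version in which both keys are available, as described above; you would normally set only one of them):

from ray.rllib.algorithms.ppo import PPOConfig

config = PPOConfig().training(
    model={
        "fcnet_hiddens": [256, 256],
        # Option 1: learn the log-std as a free, state-independent bias term.
        "free_log_std": True,
        # Option 2: clamp the network-predicted log-stds to [-1, 1] instead.
        # "log_std_clip_param": 1.0,
    }
)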

Thanks again to all who have contributed to this discussion and thereby helped improve this library!


Excellent! Thank you for your feedback and support. Looking forward to using it.
