NaN in the policy network after training for a longer duration

Setup:
We are using a custom multi-agent hierarchical environment with custom policy networks. The policy network consists of a GCN layer (from PyTorch Geometric) followed by several fully connected layers. We train with the PPO trainer and multiple rollout workers, on Ray version 2.2.0.
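For context, the model is structured roughly like this (a simplified sketch, not our exact code; layer sizes and the obs fields are illustrative):

```python
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class GCNPolicy(TorchModelV2, nn.Module):
    """Simplified sketch of our custom model: one GCN layer -> FC layers."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.gcn = GCNConv(in_channels=64, out_channels=128)  # illustrative sizes
        self.fc = nn.Sequential(
            nn.Linear(128, 128), nn.ReLU(), nn.Linear(128, num_outputs)
        )
        self.value_head = nn.Linear(128, 1)
        self._features = None

    def forward(self, input_dict, state, seq_lens):
        # In the real model the node features and edge_index are rebuilt from
        # the graph-structured observation; shown here as two obs fields.
        x = input_dict["obs"]["node_features"]                # [num_nodes, 64]
        edge_index = input_dict["obs"]["edge_index"].long()   # [2, num_edges]
        node_mat = self.gcn(x, edge_index)   # this output turns NaN after ~2K iters
        self._features = node_mat.mean(dim=0, keepdim=True)   # simple graph pooling
        return self.fc(self._features), state

    def value_function(self):
        return self.value_head(self._features).squeeze(1)
```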

Issue:

After training for around 2K iterations, the output of the GCN layer contains NaNs. In most cases we also see the warning "RuntimeWarning: overflow encountered in multiply". We obtained the following stack trace for the warning:

ray.exceptions.RayTaskError(RuntimeWarning): ray::ImplicitFunc.train() (pid=984967, ip=192.168.50.141, repr=experiment)
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/trainable.py", line 367, in train
    raise skipped from exception_cause(skipped)
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/function_trainable.py", line 338, in entrypoint
    self._status_reporter.get_checkpoint(),
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/function_trainable.py", line 652, in _trainable_func
    output = fn()
  File "experiment_ppo.py", line 59, in experiment
    train_results = train_agent.train()
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/trainable.py", line 367, in train
    raise skipped from exception_cause(skipped)
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/trainable.py", line 364, in train
    result = self.step()
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/algorithms/algorithm.py", line 749, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/algorithms/algorithm.py", line 2623, in _run_one_training_iteration
    results = self.training_step()
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/algorithms/ppo/ppo.py", line 326, in training_step
    train_batch = standardize_fields(train_batch, ["advantages"])
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/execution/rollout_ops.py", line 126, in standardize_fields
    batch[field] = standardized(batch[field])
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/utils/sgd.py", line 24, in standardized
    return (array - array.mean()) / max(1e-4, array.std())
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/numpy/core/_methods.py", line 270, in _std
    keepdims=keepdims, where=where)
  File "/home/student/2020/cs17m20p100001/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/numpy/core/_methods.py", line 236, in _var
    x = um.multiply(x, x, out=x)
RuntimeWarning: overflow encountered in multiply
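For reference, this overflow is easy to reproduce outside of RLlib: standardizing advantages squares them internally when computing the standard deviation, and float32 overflows once the squared values exceed float32 max (~3.4e38). A minimal demonstration (the values are made up; the computation mirrors `standardized()` in `ray/rllib/utils/sgd.py` from the trace above):

```python
import numpy as np

# Squaring anything much above ~1.8e19 overflows float32 to inf.
adv = np.array([2e19, -2e19], dtype=np.float32)

# Same computation as standardized() in ray/rllib/utils/sgd.py.
std = adv.std()  # var runs um.multiply(x, x, out=x) -> inf, hence the warning
print(std)                                   # inf
print((adv - adv.mean()) / max(1e-4, std))   # [ 0., -0.]
```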

Similarly, the stack trace for the NaN error is as follows:

ray::ImplicitFunc.train() (pid=609848, ip=192.168.50.51, repr=experiment)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/trainable.py", line 367, in train
    raise skipped from exception_cause(skipped)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/function_trainable.py", line 338, in entrypoint
    self._status_reporter.get_checkpoint(),
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/function_trainable.py", line 652, in _trainable_func
    output = fn()
  File "experiment_ppo.py", line 59, in experiment
    train_results = train_agent.train()
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/trainable.py", line 367, in train
    raise skipped from exception_cause(skipped)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/tune/trainable/trainable.py", line 364, in train
    result = self.step()
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/algorithms/algorithm.py", line 749, in step
    results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/algorithms/algorithm.py", line 2623, in _run_one_training_iteration
    results = self.training_step()
  File "/home/ai20btech11004/ML-Register-Allocation/model/RegAlloc/ggnn_drl/rllib_split_model/src/ppo_new.py", line 389, in training_step
    train_results = train_one_step(self, train_batch)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/execution/train_ops.py", line 62, in train_one_step
    [],
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/utils/sgd.py", line 130, in do_minibatch_sgd
    MultiAgentBatch({policy_id: minibatch}, minibatch.count)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1013, in learn_on_batch
    info_out[pid] = policy.learn_on_batch(batch)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_v2.py", line 616, in learn_on_batch
    grads, fetches = self.compute_gradients(postprocessed_batch)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_v2.py", line 816, in compute_gradients
    tower_outputs = self._multi_gpu_parallel_grad_calc([postprocessed_batch])
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1212, in _multi_gpu_parallel_grad_calc
    raise last_result[0] from last_result[1]
ValueError: Nan in select node model input after ggnn
Traceback (most recent call last):
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1128, in _worker
    self.loss(model, self.dist_class, sample_batch)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 87, in loss
    logits, state, extra_infos = model(train_batch)
  File "/home/ai20btech11004/anaconda3/envs/rllib_env_2.2.0/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 260, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/ai20btech11004/ML-Register-Allocation/model/RegAlloc/ggnn_drl/rllib_split_model/src/model.py", line 155, in forward
    assert not torch.isnan(node_mat).any(), "Nan in select node model input after ggnn"
AssertionError: Nan in select node model input after ggnn

We suspect that the NaNs are caused by the overflow. We also observed that the overflow occurs while the advantages are being standardized.
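A minimal illustration of the mechanism we suspect (again with made-up values): once a single advantage has already overflowed to inf, subtracting the mean produces NaN, which can then propagate through the loss and the optimizer step into the GCN weights:

```python
import numpy as np

adv = np.array([np.inf, 1.0], dtype=np.float32)
print(adv.mean())         # inf
print(adv - adv.mean())   # [nan, -inf]  (inf - inf = nan)
```

If this is right, every forward pass after the corrupted update would produce NaNs at the GCN output, which matches the assertion we see firing.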

Can someone please suggest what could be causing this issue and how we can fix it?