NaN or Inf issue with PPO and action masking system

Hi everyone,
I hope someone can help me with an issue that has been bothering me for the past week. I am training a single agent with the PPO algorithm on a custom environment, using a custom neural network architecture with an action masking system.
However, I often get the warning "NaN or Inf found in input tensor," and I still don't understand why. I have thoroughly checked my custom environment and fixed it so that it cannot return NaN or Inf values in the state vector.
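
For reference, this is roughly the kind of check I run on the environment output. It is only a sketch: the wrapper name FiniteCheckEnv is made up, and I am assuming the old-style Gym API where reset() returns just the observation and step() returns a 4-tuple.

import numpy as np
from Custom_Env import Custom_Env

class FiniteCheckEnv(Custom_Env):
    """Raises as soon as a non-finite value leaves the environment."""

    def reset(self):
        obs = super().reset()
        assert np.all(np.isfinite(obs)), f"Non-finite observation on reset: {obs}"
        return obs

    def step(self, action):
        obs, reward, done, info = super().step(action)
        assert np.all(np.isfinite(obs)), f"Non-finite observation after step: {obs}"
        assert np.isfinite(reward), f"Non-finite reward: {reward}"
        return obs, reward, done, info

When I want this check active, I simply put FiniteCheckEnv instead of Custom_Env in the config, and it has never triggered so far.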

Here is my training script:

import ray
from ray.rllib.algorithms.ppo import PPO
from Custom_Env import Custom_Env
from Custom_model import Custom_Model
from tqdm import tqdm
import numpy as np

episodes = 2500  # number of trainer.train() calls (training iterations, not env episodes)
learning_rate = 0.0005
discount_factor = 0.99

env_config = {'cl': False}
# Instantiate the env once so the model config can read n_clusters below.
cd = Custom_Env(env_config)

config = {
    "env": Custom_Env,
    "env_config": env_config,
    "framework": "torch",
    "lr": learning_rate,
    "gamma": discount_factor,
    "num_workers": 16,
    "num_envs_per_worker": 1,
    "rollout_fragment_length": 100,
    "batch_mode": "complete_episodes",
    "model": {
        "custom_model": Custom_Model,
        "custom_model_config": {
            "branches_input": [1, 1, 10, 20, 10, 10, 30, 30, 30],
            "branches_output": [1, 1, 4, 4, 4, 4, 16, 16, 16],
            "action_mask_dim": cd.n_clusters,
            "combined_hiddens": [128, 128],
            "value_branch": [128, 1]
        }
    },
    "num_gpus": 1,
    "train_batch_size": 1600,
    "sgd_minibatch_size": 800,
    "num_sgd_iter": 3,
    "clip_param": 0.1,
    "grad_clip": 40.0,
    "vf_clip_param": 10.0,
    "entropy_coeff": 0,
    "lambda": 0.99,
    "use_gae": True
}

def run():
    ray.init(ignore_reinit_error=True)
    trainer = PPO(config=config)
    for _ in tqdm(range(episodes)):
        result = trainer.train()  # one training iteration per call
    trainer.save('model')

if __name__ == '__main__':
    run()
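
I also started dumping the learner stats out of each train() result to see whether the losses themselves turn non-finite before the warning appears. This is just a sketch I run right after result = trainer.train(); as far as I can tell the stats live under result["info"]["learner"], but the exact nesting may differ between Ray versions, so the helper just walks whatever it finds there.

import math

def check_result_for_nans(result):
    # Walk the per-policy learner stats and report any non-finite numbers.
    learner_info = result.get("info", {}).get("learner", {})
    for policy_id, policy_info in learner_info.items():
        stats = policy_info.get("learner_stats", policy_info)
        for key, value in stats.items():
            try:
                v = float(value)
            except (TypeError, ValueError):
                continue
            if not math.isfinite(v):
                print(f"Non-finite learner stat {policy_id}/{key}: {v}")

So far it has not pointed me at anything conclusive either.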

And here is my custom model:

import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class Custom_Model(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.num_outputs = num_outputs
        # custom_model_config is passed through via **kwargs; fall back to
        # (obs size - mask size) if 'input_size' is not provided explicitly.
        combined_input = kwargs.get('input_size', obs_space.shape[0] - kwargs['action_mask_dim'])
        combined_hiddens = kwargs['combined_hiddens']
        value_branch = kwargs.get('value_branch', None)

        # Policy network: MLP over the feature part of the observation,
        # outputting one logit per action.
        layers = [nn.Linear(combined_input, combined_hiddens[0]), nn.ReLU()]
        for i in range(1, len(combined_hiddens)):
            layers.append(nn.Linear(combined_hiddens[i-1], combined_hiddens[i]))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(combined_hiddens[-1], num_outputs))
        self.network = nn.Sequential(*layers)

        # Optional value branch: separate MLP producing a single value estimate.
        if value_branch:
            value_layers = [nn.Linear(combined_input, value_branch[0]), nn.ReLU()]
            for i in range(1, len(value_branch)):
                value_layers.append(nn.Linear(value_branch[i-1], value_branch[i]))
                value_layers.append(nn.ReLU())
            value_layers.append(nn.Linear(value_branch[-1], 1))
            self.value_branch = nn.Sequential(*value_layers)

    def forward(self, input_dict, state, seq_lens):
        x = input_dict["obs"]
        # The flat observation ends with the action mask; split it off the features.
        features = x[:, :-self.num_outputs]
        action_mask = x[:, -self.num_outputs:]
        # If a row's mask is all zeros, unmask every action in that row so the
        # policy distribution stays well defined.
        invalid_mask = (action_mask.sum(dim=1, keepdim=True) == 0).float()
        corrected_action_mask = action_mask + invalid_mask * torch.ones_like(action_mask)

        logits = self.network(features)

        # Clamp before the log so masked-out actions get log(1e-10) ~ -23
        # instead of -inf, then add that penalty to the raw logits.
        corrected_action_mask = torch.clamp(corrected_action_mask, min=1e-10)
        inf_mask = torch.log(corrected_action_mask)
        masked_output = logits + inf_mask

        # Keep the features around so value_function() can fall back to a
        # correctly shaped zero tensor when no value branch is configured.
        self._features = features
        self._value_output = self.value_branch(features) if hasattr(self, 'value_branch') else None

        return masked_output, state

    def value_function(self):
        if self._value_output is not None:
            return self._value_output.squeeze(-1)
        # No value branch configured: return a zero value estimate per batch element.
        return torch.zeros(self._features.shape[0], device=self._features.device)
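
For completeness, here is the masking variant I have been considering instead of adding log(mask) to the logits: overwrite the logits of invalid actions with a very large negative number. It is only a sketch, and the helper name and the use of the dtype minimum are my own choices.

import torch

def apply_action_mask(logits, action_mask):
    # Keep the same fallback as above: if a row's mask is all zeros,
    # treat every action in that row as valid.
    all_invalid = action_mask.sum(dim=1, keepdim=True) == 0
    valid = action_mask.bool() | all_invalid
    # Push invalid actions to the most negative representable logit instead
    # of adding log(mask), so there is no log/clamp of the mask at all.
    masked_value = torch.finfo(logits.dtype).min
    return torch.where(valid, logits, torch.full_like(logits, masked_value))

In forward() I would then return apply_action_mask(logits, action_mask) instead of logits + inf_mask, but I have not verified whether that changes anything about the NaN warning.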

Could anyone please help me?
Thank you in advance!

Best regards,
L.