PPO custom model with LSTM

1. Severity of the issue: (select one)
[X] High: Completely blocks me.

Hello, I’m training a single-agent environment with PPO and a MultiDiscrete action space. I eventually want to get PPO working with a custom LSTM model, because I would like to add action masking.

When I keep everything else constant (training script and environment) and use either the default feed-forward model (no LSTM) or the built-in LSTM wrapper (use_lstm=True), I get a different action tuple every time policy/model inference is called. For the built-in LSTM, the state_out values are floats.
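
For reference, the working built-in-LSTM baseline is configured roughly like this (a minimal sketch; the env id is a placeholder for my registered environment):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("MyEnv-v0")  # placeholder for my registered env
    .framework("torch")
    .training(model={"use_lstm": True})
)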

When I use a custom model with an LSTM, I get the same action tuple (9, 13, 6) on every inference call, even for random observations. Also, state_out contains many 0, -1, and 1 entries, plus other integer values that increase with the step count (40, then 41, then 42, then 43, ...). It looks like some sort of saturation at -1 and 1, plus tanh-like values (tanh(±1) ≈ ±0.76159), along with strange large values such as 40, 41, 42, 43 that never appear with the built-in model.
The custom model is configured with:

.training(model={
    "custom_model": "my_torch_model",
    "use_lstm": False,
})
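
The custom model itself is registered beforehand; a minimal sketch of the registration, with the key matching the config above:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("my_torch_model", PPOLSTMModel)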

Does anyone have ideas about why the behavior differs between the no-LSTM/built-in-LSTM models and the custom LSTM model? How can I debug this, find out more about why it is happening, and resolve it?
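
In case it helps, this is roughly how I probe the policy with random observations (a sketch assuming a flat Box observation space; compute_single_action and get_initial_state are the standard Policy APIs):

import numpy as np

algo = config.build()
policy = algo.get_policy()
state = policy.get_initial_state()
for _ in range(5):
    obs = np.random.random(policy.observation_space.shape).astype(np.float32)
    action, state, _ = policy.compute_single_action(obs, state=state)
    print(action, [np.asarray(s).ravel()[:4] for s in state])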

Versions:
ray 2.7.1
python 3.9.0
gymnasium 1.0.0
Windows

Custom network architecture:

import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.models import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.rnn_sequencing import add_time_dimension

# ====================== Custom LSTM Model ======================
class PPOLSTMModel(TorchRNN, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        fc_size=256,
        lstm_state_size=64,
    ):
        num_outputs = 60  # hard-coded total number of action logits
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.obs_size = 10032  # flattened observation size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        self.batch_first = False  # time-major: nn.LSTM inputs are [T, B, ...]

        self.fc1 = nn.Linear(self.obs_size, self.fc_size)

        # (Note: this stack is defined but never called in forward()/forward_rnn().)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(10093, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, self.fc_size)
        )

        self.lstm = nn.LSTM(self.fc_size, self.lstm_state_size, batch_first=self.batch_first)

        # action_branch heads and value heads after lstm layers
        self.action_branch = nn.Sequential()
        self.value_branch = nn.Sequential()
        
        # Build up the action and value modules
        activation = "relu"

        self.action_branch.add_module(
            "action_branch_linear1",
            SlimFC(64, 256, activation_fn=activation),
        )
        self.value_branch.add_module(
            "value_branch_linear1",
            SlimFC(64, 256, activation_fn=activation),
        )

        self.action_branch.add_module(
            "action_branch_linear2", SlimFC(256, num_outputs, activation_fn=None)
        )
        self.value_branch.add_module(
            "value_branch_linear2", SlimFC(256, 1, activation_fn=None)
        )

        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_(),
            self.fc1.weight.new(1, self.lstm_state_size).zero_(),
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward(self, input_dict, state, seq_lens):
        assert seq_lens is not None
        flat_inputs = input_dict["obs_flat"].float()
        inputs = add_time_dimension(
            flat_inputs,
            seq_lens=seq_lens,
            framework="torch",
            time_major=not self.batch_first,
        )
        if not state or len(state) < 2:
            state = self.get_initial_state()
            print('>>> forward(): state was empty, using get_initial_state()')
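        # Assumption: each state item arrives here as [B, 1, lstm_state_size]
        # (RLlib stacks the [1, size] tensors from get_initial_state() across
        # the batch), so permute to [1, B, size] as nn.LSTM expects.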
        state_permuted = [state[0].permute(1, 0, 2), state[1].permute(1, 0, 2)]
        output, new_state = self.forward_rnn(inputs, state_permuted, seq_lens)
        return output, new_state

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the Gru Unit.

        Returns the resulting outputs as a sequence (B x T x ...).
        Values are stored in self._cur_value in simple (B) shape (where B
        contains both the B and T dims!).

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (c- and h-states).
        """
        if inputs.dim() == 4:
            print("\nPPOLSTMModel forward_rnn()")
            print("inputs.dim() == 4", inputs.dim() == 4)
            print("inputs.shape", inputs.shape)

        x = nn.functional.relu(self.fc1(inputs))
        if state == []:
            state = self.get_initial_state()

        self._features, [h, c] = self.lstm(
            x, [state[0], state[1]]
        )

        logits = self.action_branch(self._features)

        # Collapse time back into the batch dimension: [T*B, num_outputs].
        logits = logits.reshape(-1, logits.shape[-1])
        return logits, [h, c]
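
For comparison, RLlib's own TorchRNNModel example (rllib/examples/models/rnn_model.py) keeps each initial state 1-D and unsqueezes/squeezes around the LSTM call instead of permuting; a condensed sketch of that pattern (not my code):

# Condensed from RLlib's TorchRNNModel example.
def get_initial_state(self):
    # 1-D states; RLlib stacks these to [B, lstm_state_size].
    return [
        self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
        self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
    ]

def forward_rnn(self, inputs, state, seq_lens):
    x = nn.functional.relu(self.fc1(inputs))
    # Unsqueeze [B, size] -> [1, B, size] for nn.LSTM ...
    self._features, [h, c] = self.lstm(
        x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]
    )
    action_out = self.action_branch(self._features)
    # ... and squeeze back to [B, size] before returning to RLlib.
    return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]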

I saw these two posts that might be similar (PPO with a custom LSTM model), but they didn't have any relevant answers:

@mannyv @sven1977 @christina
Thanks!