1. Severity of the issue: (select one)
[X] High: Completely blocks me.
Hello, I'm training a single-agent environment with PPO and a MultiDiscrete action space. Eventually I want PPO to work with a custom LSTM model, because I would like to add action masking.
Keeping everything else constant (same training script and same environment), both the plain feed-forward model (PPO's default, no LSTM) and the built-in LSTM wrapper (use_lstm=True) give me a different action tuple every time policy/model inference is called, and with the built-in LSTM the state_out values are ordinary floats.
When I use my custom model with an LSTM, I get the same action tuple (9, 13, 6) for every inference call, even when I feed in random observations. The state_out also contains many values that are exactly 0, -1, or 1, values around ±0.76159 (which is tanh(±1), so some kind of saturation seems to be going on), and strange large integer-like values that increase with the step count (40, then 41, then 42, then 43, ...). None of this happens with the built-in model.
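For reference, a quick scratch check of that 0.76159 number (not part of my training code), just to show it lines up with tanh saturating:

import torch

# tanh(±1) = ±0.76159..., and large-magnitude inputs pin tanh at ±1,
# which is what saturated LSTM gates / cell states look like.
print(torch.tanh(torch.tensor([1.0, -1.0, 10.0, -10.0])))
# tensor([ 0.7616, -0.7616,  1.0000, -1.0000])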
The custom model is selected in my PPO config with:

.training(
    model={
        "custom_model": "my_torch_model",
        "use_lstm": False,
    }
)
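For comparison, the built-in-LSTM run that behaves correctly uses something like the following (a sketch; the cell size and max_seq_len here are illustrative, not necessarily the exact values in my script):

.training(
    model={
        "use_lstm": True,
        "lstm_cell_size": 64,  # illustrative
        "max_seq_len": 20,     # illustrative
    }
)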
Does anyone have ideas about why the custom LSTM model behaves so differently from the no-LSTM and built-in-LSTM setups? How can I debug this or find out more about why it is happening, and how can I resolve it?
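In case it helps, this is roughly how I have been poking at the custom model outside of RLlib (a standalone sketch that assumes the PPOLSTMModel class pasted further down; the Box/MultiDiscrete spaces here are stand-ins for my real spaces, and MODEL_DEFAULTS is only used to get a complete model_config dict):

import numpy as np
import torch
from gymnasium import spaces
from ray.rllib.models import MODEL_DEFAULTS

# Stand-in spaces (my real env differs; the obs is a flat 10032-dim vector).
obs_space = spaces.Box(-1.0, 1.0, shape=(10032,), dtype=np.float32)
action_space = spaces.MultiDiscrete([20, 20, 20])  # placeholder nvec

model = PPOLSTMModel(obs_space, action_space, 60, dict(MODEL_DEFAULTS), "debug")

# Batch of 1 with a zero initial state (unsqueeze adds the batch dim RLlib normally adds).
state = [s.unsqueeze(0) for s in model.get_initial_state()]
for step in range(3):
    obs = torch.rand(1, 10032)
    logits, state = model.forward({"obs_flat": obs}, state, torch.tensor([1]))
    print(step, logits.shape, [s.shape for s in state])
    print(logits[0, :8])  # do these change at all with random observations?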
Versions:
Ray 2.7.1
Python 3.9.0
Gymnasium 1.0.0
Windows

Custom network architecture:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ray.rllib.models import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.rnn_sequencing import add_time_dimension
# ====================== Custom LSTM Model ======================
class PPOLSTMModel(TorchRNN, nn.Module):
    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        fc_size=256,
        lstm_state_size=64,
    ):
        num_outputs = 60
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.obs_size = 10032
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size
        self.batch_first = False

        self.fc1 = nn.Linear(self.obs_size, self.fc_size)
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(10093, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, self.fc_size),
        )
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=self.batch_first
        )

        # action_branch heads and value heads after lstm layers
        self.action_branch = nn.Sequential()
        self.value_branch = nn.Sequential()
        # Build up the action and value modules
        activation = "relu"
        self.action_branch.add_module(
            "action_branch_linear1",
            SlimFC(64, 256, activation_fn=activation),
        )
        self.value_branch.add_module(
            "value_branch_linear1",
            SlimFC(64, 256, activation_fn=activation),
        )
        self.action_branch.add_module(
            "action_branch_linear2", SlimFC(256, num_outputs, activation_fn=None)
        )
        self.value_branch.add_module(
            "value_branch_linear2", SlimFC(256, 1, activation_fn=None)
        )

        # Holds the current "base" output (before logits layer).
        self._features = None
    @override(ModelV2)
    def get_initial_state(self):
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_(),
            self.fc1.weight.new(1, self.lstm_state_size).zero_(),
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward(self, input_dict, state, seq_lens):
        assert seq_lens is not None
        flat_inputs = input_dict["obs_flat"].float()
        inputs = add_time_dimension(
            flat_inputs,
            seq_lens=seq_lens,
            framework="torch",
            time_major=not self.batch_first,
        )
        if not state or len(state) < 2:
            state = self.get_initial_state()
            print('>>> forward(): state was empty, using get_initial_state()')
        state_permuted = [state[0].permute(1, 0, 2), state[1].permute(1, 0, 2)]
        output, new_state = self.forward_rnn(inputs, state_permuted, seq_lens)
        return output, new_state
    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the LSTM.

        Returns the resulting outputs as a sequence (B x T x ...).
        Features are stored in self._features and reused by value_function().

        Returns:
            NN outputs (B x T x ...) as a sequence.
            The state batches as a list of two items (h- and c-states).
        """
        if inputs.dim() == 4:
            print("\nPPOLSTMModel forward_rnn()")
            print('inputs.dim() == 4', inputs.dim() == 4)
            print('inputs.shape', inputs.shape)
        x = nn.functional.relu(self.fc1(inputs))
        if state == []:
            state = self.get_initial_state()
        self._features, [h, c] = self.lstm(x, [state[0], state[1]])
        logits = self.action_branch(self._features)
        logits = logits.reshape(-1, logits.shape[-1])
        return logits, [h, c]
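For completeness, the model is registered and wired into PPO like this (a sketch; "MyEnv" is a placeholder for my registered environment, and my real PPOConfig has more settings than shown):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

# Register the custom model under the name referenced in the config above.
ModelCatalog.register_custom_model("my_torch_model", PPOLSTMModel)

config = (
    PPOConfig()
    .environment("MyEnv")  # placeholder; the real env is registered elsewhere
    .framework("torch")
    .training(
        model={
            "custom_model": "my_torch_model",
            "use_lstm": False,
        }
    )
)
algo = config.build()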
I saw these two posts that seemed similar (PPO with a custom LSTM model), but they didn't have any relevant answers:
@mannyv @sven1977 @christina
Thanks!