PPO+LSTM consistently not working

I am having problems with the LSTM hidden states in a custom PPO+LSTM model in RLlib.
The hidden-state shapes my model passes to the LSTM are not what the LSTM expects, and I do not understand why. Here is my code:


import torch
import torch.nn as nn

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override


class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.obs_size = obs_space.shape[0]  # Assuming obs_space is already shaped (12,)
        self.hidden_dim = 128 # Hidden dimension for LSTM and Dense layers
        self.lstm_hidden_state_size = 15  # Size of LSTM hidden state
        
        self.input_layer = nn.Linear(self.obs_size + 1 + action_space.shape[0], self.hidden_dim)
        self.lstm = nn.LSTM(self.hidden_dim, self.lstm_hidden_state_size, batch_first=True)
        self.output_layer = nn.Linear(self.lstm_hidden_state_size, num_outputs)
        self.logits_layer = nn.Linear(self.lstm_hidden_state_size, action_space.shape[0])
        self.log_std = nn.Parameter(torch.zeros(action_space.shape[0]))
    
    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        print("obs_size", obs.shape)
        prev_reward = input_dict["prev_rewards"].unsqueeze(-1)
        print("prev_reward_size", prev_reward.shape)
        last_actions = input_dict["prev_actions"]
        print("last_actions shape", last_actions.shape)

        x = torch.cat([obs, prev_reward, last_actions], dim=-1)
        x = torch.relu(self.input_layer(x))
        batch_size = x.size(0)  # (currently unused)

        print("Input to LSTM x shape:", x.shape)  # Debugging output
        print("State shapes upon receiving in forward:", [s.shape for s in state])  # Debugging output

        if not state or len(state) < 2 or any(s is None for s in state):
            raise ValueError(f"Invalid state received: {state}")
        # Reshape states to correct the unexpected extra dimension
        h0, c0 = state
        if h0.dim() == 4:
            h0 = h0.squeeze(2)  # Remove the unexpected dimension
        if c0.dim() == 4:
            c0 = c0.squeeze(2)  # Remove the unexpected dimension

        x, new_state = self.lstm(x.unsqueeze(0), (h0, c0))
        print("New state", new_state)
        x = x.squeeze(0)
        self._last_layer_out = x  # cache features for value_function()

        logits = self.logits_layer(x)
        return logits, new_state

    def value_function(self):
        return self.output_layer(self._last_layer_out)
    
    @override(TorchModelV2)
    def get_initial_state(self):
        # Per-sample initial states; RLlib stacks these across the sample batch
        # before calling forward(), which is why they arrive as [B, 1, hidden].
        return [torch.zeros(1, self.lstm_hidden_state_size),
                torch.zeros(1, self.lstm_hidden_state_size)]
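
Not shown in the snippet above: the model also has to be registered with RLlib's ModelCatalog so that the "custom_torch_model" name used in the config resolves, along the lines of:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("custom_torch_model", CustomTorchModel)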

and the output of the print statements:

(PPO pid=20660) obs_size torch.Size([32, 12])
(PPO pid=20660) prev_reward_size torch.Size([32, 1])
(PPO pid=20660) last_actions shape torch.Size([32, 2])
(PPO pid=20660) Input to LSTM x shape: torch.Size([32, 128])
(PPO pid=20660) State shapes upon receiving in forward: [torch.Size([32, 1, 15]), torch.Size([32, 1, 15])]

So the state tensors arrive in forward() batched as [32, 1, 15] (one per sample), while the LSTM expects a hidden state of size (1, 1, 15). Here is the error:

(PPO pid=20660)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 874, in forward
(PPO pid=20660)     self.check_forward_args(input, hx, batch_sizes)
(PPO pid=20660)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 790, in check_forward_args
(PPO pid=20660)     self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
(PPO pid=20660)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 259, in check_hidden_size
(PPO pid=20660)     raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
(PPO pid=20660) RuntimeError: Expected hidden[0] size (1, 1, 15), got [32, 1, 15]
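
For context, the expected shape comes straight from how torch.nn.LSTM interprets the input. A minimal standalone sketch with the same layer sizes as above (plain PyTorch, no RLlib) reproduces it:

import torch
import torch.nn as nn

lstm = nn.LSTM(128, 15, batch_first=True)

# x.unsqueeze(0) turns the [32, 128] feature batch into [1, 32, 128].
# With batch_first=True that is read as batch=1, seq_len=32, features=128,
# so the hidden states must be [num_layers=1, batch=1, hidden=15].
x = torch.zeros(1, 32, 128)
h0 = torch.zeros(1, 1, 15)
c0 = torch.zeros(1, 1, 15)
out, (hn, cn) = lstm(x, (h0, c0))  # works

# Passing the RLlib-batched state of shape [32, 1, 15] instead raises exactly:
# RuntimeError: Expected hidden[0] size (1, 1, 15), got [32, 1, 15]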

and here is the config and the Tune call:


config = (
    PPOConfig()
    .environment(env='FlowFieldNavEnv-v0')
    .framework("torch")
    
    .rollouts(
        num_rollout_workers=0,  # Number of parallel workers
        num_envs_per_worker=1,  # Number of environments each worker simulates
        rollout_fragment_length=1
    )
    
    .resources(
        num_gpus=0 # Adjust based on your GPU availability
    )
    
    .training(
        # lr=tune.grid_search([1e-4, 1e-5, 1e-6]),
        lr=1e-4,
        train_batch_size=2000,
        sgd_minibatch_size=10,
        num_sgd_iter=10,
        # model={
        #     "fcnet_hiddens": [128, 64, 32],
        #     "fcnet_activation": "tanh",
        #     "use_lstm": True,
        #     "lstm_cell_size": 32,
        #     "lstm_use_prev_action": True,
        #     "lstm_use_prev_reward": True,
        #     "vf_share_layers": True,
        # },
        model={
            "fcnet_hiddens": [128],
            # "fcnet_activation": "relu",
            "use_lstm": False,
            "custom_model": "custom_torch_model",
            "custom_model_config": {},
        },
    )
    
    .reporting(
        metrics_num_episodes_for_smoothing=5,
        min_sample_timesteps_per_iteration=1000,
    )
    
    .evaluation(
        evaluation_duration=100,
        evaluation_num_workers=1,
        evaluation_interval=10,
        evaluation_parallel_to_training=True,
        evaluation_config={
            "explore": False  # Evaluation runs with deterministic policy
        },
    )
)
    
    

stop_criteria = {
    "training_iteration": 1000
}

# Use Tune to run the training (and optionally tune hyperparameters)
from ray import train, tune

tuner = tune.Tuner(
    "PPO",
    param_space=config,
    run_config=train.RunConfig(
        stop=stop_criteria,
        name="PPO_Flow_field_nav_",
    ),
    tune_config=tune.TuneConfig(
        num_samples=1,
        metric="episode_reward_mean",
        mode="max"
    ),
)

result_grid = tuner.fit()
best_result = result_grid.get_best_result()
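
For reference, this is how I currently understand the shapes would have to line up inside forward(), based on the torch.nn.LSTM documentation. This is only a sketch (not tested, and it ignores proper seq_lens handling for longer sequences); the transposes assume the states really do arrive batch-first as [32, 1, 15], as printed above. I am not sure this is the intended way to do it in RLlib:

# x: [batch=32, hidden_dim] features coming out of self.input_layer
# state[i]: [batch=32, 1, 15] as printed above

h0 = state[0].transpose(0, 1).contiguous()  # -> [num_layers=1, batch=32, 15]
c0 = state[1].transpose(0, 1).contiguous()  # -> [num_layers=1, batch=32, 15]

x = x.unsqueeze(1)                          # -> [batch=32, time=1, hidden_dim]
out, (hn, cn) = self.lstm(x, (h0, c0))      # hidden stays (layers, batch, hidden) even with batch_first=True

out = out.squeeze(1)                        # -> [batch=32, 15]
new_state = [hn.transpose(0, 1), cn.transpose(0, 1)]  # back to the incoming [batch, 1, 15] layout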

I have tried everything I can think of, from reshaping to squeezing, but I do not understand why I get 32 instead of 1.
I hope someone can help me :frowning: