PPO + LSTM custom model implementation problem (Ray 2.10.0)

I have a problem when trying to implement my custom model. My impression is that RLlib is not only calling my custom model but also, somehow, the non-overridden methods of the built-in LSTM wrapper. I paste my custom model here, together with the errors I get :frowning:

import numpy as np
import torch
import torch.nn as nn
from typing import Dict, List, Tuple, Union

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType


class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        nn.Module.__init__(self)
        super(CustomTorchModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)

        # Configuration
        self.obs_size = obs_space.shape[0]
        self.action_size = action_space.n if hasattr(action_space, "n") else action_space.shape[0]
        self.input_size = self.obs_size + self.action_size + 1  # obs + last action + last reward
        self.lstm_hidden_size = model_config["lstm_cell_size"]
        self.num_layers = 1

        # Fully Connected Layers
        self.fc1 = nn.Linear(self.input_size, 64)
        self.fc2 = nn.Linear(64, 15)

        # LSTM Configuration
        self.lstm = nn.LSTM(15, self.lstm_hidden_size, batch_first=True)

        # Output Layers
        self.policy_mean = nn.Linear(self.lstm_hidden_size, self.action_size)
        self.policy_std = nn.Linear(self.lstm_hidden_size, self.action_size)
        self.value_head = nn.Linear(self.lstm_hidden_size, 1)

        # Initial Value Placeholder
        self._value_out = None

        # View Requirements for RNNs
        if model_config["lstm_use_prev_action"]:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                "actions", shift=-1, space=action_space
            )
        if model_config["lstm_use_prev_reward"]:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                "rewards", shift=-1
            )

    @override(TorchModelV2)
    def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType], seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        print(f"Forward call with state: {state} and seq_lens: {seq_lens}")
        print(f"Shape of obs: {input_dict['obs'].shape}")
        print(f"Shape of prev_actions: {input_dict['prev_actions'].shape}")
        print(f"Shape of prev_rewards: {input_dict['prev_rewards'].shape}")

        # Concatenate Inputs
        x = torch.cat([
            input_dict["obs"],
            input_dict["prev_actions"],
            input_dict["prev_rewards"].unsqueeze(-1)
        ], dim=1)

        assert x.shape[-1] == self.input_size, f"Input size mismatch: Expected {self.input_size}, got {x.shape[-1]}"

        print(f"Concatenated input x shape: {x.shape}")

        # Fully Connected Layers
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        print(f"Post-FC layer shape: {x.shape}")

        x = x.unsqueeze(1)  # For LSTM compatibility
        print(f"Input to LSTM shape: {x.shape}")

        return self.forward_rnn(x, state, seq_lens)
        
    @override(RecurrentNetwork)
    def forward_rnn(self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        print(f"Shape of inputs before LSTM: {inputs.shape}")
        print(f"State shapes: {[s.shape for s in state]}")

        if state is None or len(state) == 0:
            h0 = torch.zeros(self.num_layers, inputs.size(0), self.lstm_hidden_size, device=inputs.device)
            c0 = torch.zeros(self.num_layers, inputs.size(0), self.lstm_hidden_size, device=inputs.device)
            print("Generated new initial states for LSTM")
        else:
            h0, c0 = state[0], state[1]

        x, (hn, cn) = self.lstm(inputs, (h0, c0))

        print(f"Output from LSTM x shape: {x.shape}, hn shape: {hn.shape}, cn shape: {cn.shape}")

        x = x.squeeze(1)  # Remove sequence length dimension
        print(f"Squeezed output shape: {x.shape}")

        action_mean = self.policy_mean(x)
        action_std = torch.exp(self.policy_std(x))
        print(f"Action mean shape: {action_mean.shape}, Action std shape: {action_std.shape}")

        self._value_out = self.value_head(x)
        print(f"Value output shape: {self._value_out.shape}")

        return torch.cat([action_mean, action_std], dim=-1), [hn, cn]

    @override(TorchModelV2)
    def value_function(self):
        assert self._value_out is not None, "Must call forward() first"
        return self._value_out.squeeze(1)

    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        return [torch.zeros(self.num_layers, 1, self.lstm_hidden_size),
                torch.zeros(self.num_layers, 1, self.lstm_hidden_size)]

Here are the errors and the output of my print statements:

2024-05-09 14:16:37,584	ERROR tune_controller.py:1332 -- Trial task failed for trial PPO_FlowFieldNavEnv-v0_f861a_00000
Traceback (most recent call last):
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
             ^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/worker.py", line 2667, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/worker.py", line 866, in get_objects
    raise value
ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=12275, ip=127.0.0.1, actor_id=15c768a5baccc2f38a7a7dce01000000, repr=PPO)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 533, in __init__
    super().__init__(
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
    self.setup(copy.deepcopy(self.config))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 631, in setup
    self.workers = WorkerSet(
                   ^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 159, in __init__
    self._setup(
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 250, in _setup
    self._local_worker = self._make_worker(
                         ^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 1016, in _make_worker
    worker = cls(
             ^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 535, in __init__
    self._update_policy_map(policy_dict=self.policy_dict)
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1743, in _update_policy_map
    self._build_policy_map(
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1854, in _build_policy_map
    new_policy = create_policy_for_framework(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/policy.py", line 141, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
    self._initialize_loss_from_dummy_batch()
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/policy.py", line 1396, in _initialize_loss_from_dummy_batch
    actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 557, in compute_actions_from_input_dict
    return self._compute_action_helper(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
    return func(self, *a, **k)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1260, in _compute_action_helper
    dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/modelv2.py", line 255, in __call__
    res = self.forward(restored, state or [], seq_lens)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/torch/recurrent_net.py", line 219, in forward
    wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/Desktop/KTH_work/2D_obstacles/PPO/navigation/untitled2.py", line 167, in forward
    return self.forward_rnn(x, state, seq_lens)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/torch/recurrent_net.py", line 274, in forward_rnn
    inputs, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]
                             ~~~~~^^^
IndexError: list index out of range
2024-05-09 14:16:37,588	INFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/Users/federicatonti/ray_results/PPO_Flow_field_nav_' in 0.0019s.
2024-05-09 14:16:37,589	ERROR tune.py:1044 -- Trials did not complete: [PPO_FlowFieldNavEnv-v0_f861a_00000]
2024-05-09 14:16:37,597	WARNING experiment_analysis.py:190 -- Failed to fetch metrics for 1 trial(s):
- PPO_FlowFieldNavEnv-v0_f861a_00000: FileNotFoundError('Could not fetch metrics for PPO_FlowFieldNavEnv-v0_f861a_00000: both result.json and progress.csv were not found at /Users/federicatonti/ray_results/PPO_Flow_field_nav_/PPO_FlowFieldNavEnv-v0_f861a_00000_0_2024-05-09_14-16-31')

Trial PPO_FlowFieldNavEnv-v0_f861a_00000 errored after 0 iterations at 2024-05-09 14:16:37. Total running time: 5s
Error file: /tmp/ray/session_2024-05-09_14-16-29_794196_11064/artifacts/2024-05-09_14-16-31/PPO_Flow_field_nav_/driver_artifacts/PPO_FlowFieldNavEnv-v0_f861a_00000_0_2024-05-09_14-16-31/error.txt

Trial status: 1 ERROR
Current time: 2024-05-09 14:16:37. Total running time: 5s
Logical resource usage: 0/12 CPUs, 0/0 GPUs
+-----------------------------------------------+
| Trial name                           status   |
+-----------------------------------------------+
| PPO_FlowFieldNavEnv-v0_f861a_00000   ERROR    |
+-----------------------------------------------+

Number of errored trials: 1
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name                             # failures   error file                                                                                                                                                                              |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| PPO_FlowFieldNavEnv-v0_f861a_00000              1   /tmp/ray/session_2024-05-09_14-16-29_794196_11064/artifacts/2024-05-09_14-16-31/PPO_Flow_field_nav_/driver_artifacts/PPO_FlowFieldNavEnv-v0_f861a_00000_0_2024-05-09_14-16-31/error.txt |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+


And here the shapes of the tensors:

(PPO pid=12275) Forward call with state: [] and seq_lens: None
(PPO pid=12275) Shape of obs: torch.Size([32, 12])
(PPO pid=12275) Shape of prev_actions: torch.Size([32, 2])
(PPO pid=12275) Shape of prev_rewards: torch.Size([32])
(PPO pid=12275) Concatenated input x shape: torch.Size([32, 15])
(PPO pid=12275) Post-FC layer shape: torch.Size([32, 15])
(PPO pid=12275) Input to LSTM shape: torch.Size([32, 1, 15])
(PPO pid=12275) Shape of inputs before LSTM: torch.Size([32, 1, 15])
(PPO pid=12275) State shapes: []

I don’t understand why the state is empty :frowning:

Hi @Federica_Tonti,

If you have a custom recurrent model, you should set use_lstm to False in the model config.
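
For example, something along these lines (a minimal sketch, not your exact setup: the registered model name "custom_lstm_model", the cell size, and max_seq_len are placeholders, and it assumes the FlowFieldNavEnv-v0 env from your logs is registered):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

# Register the custom model under an arbitrary name.
ModelCatalog.register_custom_model("custom_lstm_model", CustomTorchModel)

config = (
    PPOConfig()
    .environment("FlowFieldNavEnv-v0")
    .framework("torch")
    .training(
        model={
            "custom_model": "custom_lstm_model",
            # Keep the built-in wrapper off: with "use_lstm": True, RLlib wraps
            # the custom model in its own LSTMWrapper and calls that wrapper's
            # forward()/forward_rnn(), which is what your traceback shows.
            "use_lstm": False,
            "lstm_cell_size": 32,
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
            "max_seq_len": 20,
        }
    )
)
algo = config.build()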

Thanks a lot for your reply!!! Just a quick question then: in the model logs I see that use_lstm is False, but that does not imply RLlib is not using the LSTM, right? Sorry, I am rather new to RLlib!

@mannyv sorry for disturbing you again; maybe you can help me further, I am quite desperate at the moment. I removed that, but I have probably already messed up other things. I did manage to make the code run, but now it gets stuck at other points.
I tried to simplify my model a bit, like this:


class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.obs_size = obs_space.shape[0]  # Assuming obs_space is already shaped (12,)
        self.hidden_dim = 32  # Hidden dimension for LSTM and Dense layers
        self.lstm_hidden_state_size = 32  # Size of LSTM hidden state
        
        # Input will be observation + previous reward + last actions (2-dimensional)
        self.input_layer = nn.Linear(self.obs_size + 1 + action_space.shape[0], self.hidden_dim)

        # LSTM layer
        self.lstm = nn.LSTM(self.hidden_dim, self.lstm_hidden_state_size, batch_first=True)

        # Output layer for action means
        self.logits_layer = nn.Linear(self.lstm_hidden_state_size, num_outputs)

        # Logits (action mean) and action log std for the continuous action space
        self.log_std = nn.Parameter(torch.zeros(action_space.shape[0]))
    
    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        prev_reward = input_dict["prev_rewards"].unsqueeze(-1)
        last_actions = input_dict["prev_actions"]

        # Concatenate observations with previous reward and last actions
        x = torch.cat([obs, prev_reward, last_actions], dim=-1)
        x = torch.relu(self.input_layer(x))

        # LSTM expects inputs of shape (batch, seq, feature)
        x, new_state = self.lstm(x.unsqueeze(0), state)
        x = x.squeeze(0)

        # Output the action logits
        logits = self.logits_layer(x)
        return logits, new_state
    
    @override(TorchModelV2)
    def value_function(self):
        # Compute value from the last layer outputs
        return self.output_layer(self._last_layer_out)

    def get_initial_state(self):
        # This method should return the initial state of the LSTM
        return [torch.zeros(1, self.lstm_hidden_state_size),
                torch.zeros(1, self.lstm_hidden_state_size)]

but now I repeatedly get:

(PPO pid=19559) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=19559, ip=127.0.0.1, actor_id=3bf1ac18d301b43f1a14dbcb01000000, repr=PPO)
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 533, in __init__
(PPO pid=19559)     super().__init__(
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
(PPO pid=19559)     self.setup(copy.deepcopy(self.config))
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 631, in setup
(PPO pid=19559)     self.workers = WorkerSet(
(PPO pid=19559)                    ^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 159, in __init__
(PPO pid=19559)     self._setup(
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 250, in _setup
(PPO pid=19559)     self._local_worker = self._make_worker(
(PPO pid=19559)                          ^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 1016, in _make_worker
(PPO pid=19559)     worker = cls(
(PPO pid=19559)              ^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 535, in __init__
(PPO pid=19559)     self._update_policy_map(policy_dict=self.policy_dict)
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1743, in _update_policy_map
(PPO pid=19559)     self._build_policy_map(
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1854, in _build_policy_map
(PPO pid=19559)     new_policy = create_policy_for_framework(
(PPO pid=19559)                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/policy.py", line 141, in create_policy_for_framework
(PPO pid=19559)     return policy_class(observation_space, action_space, merged_config)
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
(PPO pid=19559)     self._initialize_loss_from_dummy_batch()
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/policy.py", line 1396, in _initialize_loss_from_dummy_batch
(PPO pid=19559)     actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
(PPO pid=19559)                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 557, in compute_actions_from_input_dict
(PPO pid=19559)     return self._compute_action_helper(
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
(PPO pid=19559)     return func(self, *a, **k)
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1260, in _compute_action_helper
(PPO pid=19559)     dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
(PPO pid=19559)                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/modelv2.py", line 255, in __call__
(PPO pid=19559)     res = self.forward(restored, state or [], seq_lens)
(PPO pid=19559)           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/Desktop/KTH_work/2D_obstacles/PPO/navigation/untitled2.py", line 67, in forward
(PPO pid=19559)     x, new_state = self.lstm(x.unsqueeze(0), state)
(PPO pid=19559)                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
(PPO pid=19559)     return self._call_impl(*args, **kwargs)
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
(PPO pid=19559)     return forward_call(*args, **kwargs)
(PPO pid=19559)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 874, in forward
(PPO pid=19559)     self.check_forward_args(input, hx, batch_sizes)
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 790, in check_forward_args
(PPO pid=19559)     self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),
(PPO pid=19559)   File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/torch/nn/modules/rnn.py", line 259, in check_hidden_size
(PPO pid=19559)     raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
(PPO pid=19559) RuntimeError: Expected hidden[0] size (1, 1, 32), got [32, 1, 32]
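
For reference, torch.nn.LSTM with batch_first=True expects its input shaped (batch, seq_len, features) and its hidden state as a tuple (h0, c0), each tensor shaped (num_layers, batch, hidden_size); the [32, 1, 32] in the error appears to be the batched state list built from get_initial_state. A minimal sketch of the LSTM call under that assumption (state as a list of two [batch, 1, hidden_size] tensors, x as the [batch, hidden_dim] output of the input layer):

# Assumption: `state` is a list of two tensors shaped [batch, 1, hidden_size],
# and `x` has shape [batch, hidden_dim].
if state:
    h0 = state[0].transpose(0, 1).contiguous()  # -> (num_layers=1, batch, hidden)
    c0 = state[1].transpose(0, 1).contiguous()
else:
    h0 = x.new_zeros(1, x.size(0), self.lstm_hidden_state_size)
    c0 = x.new_zeros(1, x.size(0), self.lstm_hidden_state_size)

# batch_first=True: input is (batch, seq_len=1, features); hidden must be a tuple.
out, (hn, cn) = self.lstm(x.unsqueeze(1), (h0, c0))
out = out.squeeze(1)  # drop the seq_len dimension again
new_state = [hn.transpose(0, 1), cn.transpose(0, 1)]  # back to [batch, 1, hidden]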