I have a problem implementing my custom model. My impression is that RLlib does not only call my custom model, but somehow also calls the non-overridden methods of the built-in LSTM wrapper. Below are my custom model, a sketch of how I register and configure it, and the errors I get.
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import TensorType


class CustomTorchModel(TorchModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        nn.Module.__init__(self)
        super(CustomTorchModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)

        # Configuration
        self.obs_size = obs_space.shape[0]
        self.action_size = action_space.n if hasattr(action_space, "n") else action_space.shape[0]
        self.input_size = self.obs_size + self.action_size + 1  # obs + last action + last reward
        self.lstm_hidden_size = model_config["lstm_cell_size"]
        self.num_layers = 1

        # Fully connected layers
        self.fc1 = nn.Linear(self.input_size, 64)
        self.fc2 = nn.Linear(64, 15)

        # LSTM configuration
        self.lstm = nn.LSTM(15, self.lstm_hidden_size, batch_first=True)

        # Output layers
        self.policy_mean = nn.Linear(self.lstm_hidden_size, self.action_size)
        self.policy_std = nn.Linear(self.lstm_hidden_size, self.action_size)
        self.value_head = nn.Linear(self.lstm_hidden_size, 1)

        # Initial value placeholder
        self._value_out = None

        # View requirements for RNNs
        if model_config["lstm_use_prev_action"]:
            self.view_requirements[SampleBatch.PREV_ACTIONS] = ViewRequirement(
                "actions", shift=-1, space=action_space
            )
        if model_config["lstm_use_prev_reward"]:
            self.view_requirements[SampleBatch.PREV_REWARDS] = ViewRequirement(
                "rewards", shift=-1
            )

    @override(TorchModelV2)
    def forward(self, input_dict: Dict[str, TensorType], state: List[TensorType], seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        print(f"Forward call with state: {state} and seq_lens: {seq_lens}")
        print(f"Shape of obs: {input_dict['obs'].shape}")
        print(f"Shape of prev_actions: {input_dict['prev_actions'].shape}")
        print(f"Shape of prev_rewards: {input_dict['prev_rewards'].shape}")

        # Concatenate inputs: obs + last action + last reward
        x = torch.cat([
            input_dict["obs"],
            input_dict["prev_actions"],
            input_dict["prev_rewards"].unsqueeze(-1)
        ], dim=1)
        assert x.shape[-1] == self.input_size, f"Input size mismatch: Expected {self.input_size}, got {x.shape[-1]}"
        print(f"Concatenated input x shape: {x.shape}")

        # Fully connected layers
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        print(f"Post-FC layer shape: {x.shape}")

        x = x.unsqueeze(1)  # Add a time dimension for LSTM compatibility
        print(f"Input to LSTM shape: {x.shape}")
        return self.forward_rnn(x, state, seq_lens)

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs: TensorType, state: List[TensorType], seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        print(f"Shape of inputs before LSTM: {inputs.shape}")
        print(f"State shapes: {[s.shape for s in state]}")

        if state is None or len(state) == 0:
            h0 = torch.zeros(self.num_layers, inputs.size(0), self.lstm_hidden_size, device=inputs.device)
            c0 = torch.zeros(self.num_layers, inputs.size(0), self.lstm_hidden_size, device=inputs.device)
            print("Generated new initial states for LSTM")
        else:
            h0, c0 = state[0], state[1]

        x, (hn, cn) = self.lstm(inputs, (h0, c0))
        print(f"Output from LSTM x shape: {x.shape}, hn shape: {hn.shape}, cn shape: {cn.shape}")

        x = x.squeeze(1)  # Remove the time dimension
        print(f"Squeezed output shape: {x.shape}")

        action_mean = self.policy_mean(x)
        action_std = torch.exp(self.policy_std(x))
        print(f"Action mean shape: {action_mean.shape}, Action std shape: {action_std.shape}")

        self._value_out = self.value_head(x)
        print(f"Value output shape: {self._value_out.shape}")
        return torch.cat([action_mean, action_std], dim=-1), [hn, cn]

    @override(TorchModelV2)
    def value_function(self):
        assert self._value_out is not None, "Must call forward() first"
        return self._value_out.squeeze(1)

    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        return [torch.zeros(self.num_layers, 1, self.lstm_hidden_size),
                torch.zeros(self.num_layers, 1, self.lstm_hidden_size)]
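For context, this is roughly how I register and configure the model. It is a minimal sketch rather than my exact training script: the model name "custom_torch_model" and the value 64 for lstm_cell_size are placeholders.

from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

# Register the custom model under a name the config can refer to.
ModelCatalog.register_custom_model("custom_torch_model", CustomTorchModel)

config = (
    PPOConfig()
    .environment("FlowFieldNavEnv-v0")
    .framework("torch")
    .training(
        model={
            "custom_model": "custom_torch_model",
            "lstm_cell_size": 64,  # placeholder value
            "lstm_use_prev_action": True,
            "lstm_use_prev_reward": True,
            # As far as I understand, "use_lstm": True would make RLlib wrap a
            # custom model in its own LSTMWrapper (recurrent_net.py), which is
            # the wrapper that shows up in the traceback below.
        },
    )
)

tuner = tune.Tuner("PPO", param_space=config.to_dict())
tuner.fit()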
Here are the errors and the output of my print statements:
2024-05-09 14:16:37,584 ERROR tune_controller.py:1332 -- Trial task failed for trial PPO_FlowFieldNavEnv-v0_f861a_00000
Traceback (most recent call last):
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
result = ray.get(future)
^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/worker.py", line 2667, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/_private/worker.py", line 866, in get_objects
raise value
ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::PPO.__init__() (pid=12275, ip=127.0.0.1, actor_id=15c768a5baccc2f38a7a7dce01000000, repr=PPO)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 533, in __init__
super().__init__(
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
self.setup(copy.deepcopy(self.config))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py", line 631, in setup
self.workers = WorkerSet(
^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 159, in __init__
self._setup(
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 250, in _setup
self._local_worker = self._make_worker(
^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py", line 1016, in _make_worker
worker = cls(
^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 535, in __init__
self._update_policy_map(policy_dict=self.policy_dict)
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1743, in _update_policy_map
self._build_policy_map(
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1854, in _build_policy_map
new_policy = create_policy_for_framework(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/policy.py", line 141, in create_policy_for_framework
return policy_class(observation_space, action_space, merged_config)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
self._initialize_loss_from_dummy_batch()
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/policy.py", line 1396, in _initialize_loss_from_dummy_batch
actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 557, in compute_actions_from_input_dict
return self._compute_action_helper(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
return func(self, *a, **k)
^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1260, in _compute_action_helper
dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/modelv2.py", line 255, in __call__
res = self.forward(restored, state or [], seq_lens)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/torch/recurrent_net.py", line 219, in forward
wrapped_out, _ = self._wrapped_forward(input_dict, [], None)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/Desktop/KTH_work/2D_obstacles/PPO/navigation/untitled2.py", line 167, in forward
return self.forward_rnn(x, state, seq_lens)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/federicatonti/miniconda3/envs/tf2/lib/python3.11/site-packages/ray/rllib/models/torch/recurrent_net.py", line 274, in forward_rnn
inputs, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]
~~~~~^^^
IndexError: list index out of range
2024-05-09 14:16:37,588 INFO tune.py:1016 -- Wrote the latest version of all result files and experiment state to '/Users/federicatonti/ray_results/PPO_Flow_field_nav_' in 0.0019s.
2024-05-09 14:16:37,589 ERROR tune.py:1044 -- Trials did not complete: [PPO_FlowFieldNavEnv-v0_f861a_00000]
2024-05-09 14:16:37,597 WARNING experiment_analysis.py:190 -- Failed to fetch metrics for 1 trial(s):
- PPO_FlowFieldNavEnv-v0_f861a_00000: FileNotFoundError('Could not fetch metrics for PPO_FlowFieldNavEnv-v0_f861a_00000: both result.json and progress.csv were not found at /Users/federicatonti/ray_results/PPO_Flow_field_nav_/PPO_FlowFieldNavEnv-v0_f861a_00000_0_2024-05-09_14-16-31')
Trial PPO_FlowFieldNavEnv-v0_f861a_00000 errored after 0 iterations at 2024-05-09 14:16:37. Total running time: 5s
Error file: /tmp/ray/session_2024-05-09_14-16-29_794196_11064/artifacts/2024-05-09_14-16-31/PPO_Flow_field_nav_/driver_artifacts/PPO_FlowFieldNavEnv-v0_f861a_00000_0_2024-05-09_14-16-31/error.txt
Trial status: 1 ERROR
Current time: 2024-05-09 14:16:37. Total running time: 5s
Logical resource usage: 0/12 CPUs, 0/0 GPUs
+-----------------------------------------------+
| Trial name status |
+-----------------------------------------------+
| PPO_FlowFieldNavEnv-v0_f861a_00000 ERROR |
+-----------------------------------------------+
Number of errored trials: 1
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Trial name # failures error file |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| PPO_FlowFieldNavEnv-v0_f861a_00000 1 /tmp/ray/session_2024-05-09_14-16-29_794196_11064/artifacts/2024-05-09_14-16-31/PPO_Flow_field_nav_/driver_artifacts/PPO_FlowFieldNavEnv-v0_f861a_00000_0_2024-05-09_14-16-31/error.txt |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
And here are the shapes of the tensors:
(PPO pid=12275) Forward call with state: [] and seq_lens: None
(PPO pid=12275) Shape of obs: torch.Size([32, 12])
(PPO pid=12275) Shape of prev_actions: torch.Size([32, 2])
(PPO pid=12275) Shape of prev_rewards: torch.Size([32])
(PPO pid=12275) Concatenated input x shape: torch.Size([32, 15])
(PPO pid=12275) Post-FC layer shape: torch.Size([32, 15])
(PPO pid=12275) Input to LSTM shape: torch.Size([32, 1, 15])
(PPO pid=12275) Shape of inputs before LSTM: torch.Size([32, 1, 15])
(PPO pid=12275) State shapes: []
I don't understand why the state list passed to forward() is empty.
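For comparison, the RLlib torch RNN examples I have looked at return unbatched state tensors of shape (lstm_cell_size,) from get_initial_state() and let RLlib add the batch dimension; my version returns (num_layers, 1, lstm_cell_size) tensors instead. I am not sure whether this difference is related to the empty state, but this is the variant I mean (a sketch of my understanding, not verified against my Ray version):

    # get_initial_state() as in the RLlib torch RNN examples (sketch): one
    # unbatched tensor per LSTM state; RLlib is supposed to add the batch
    # dimension, and the LSTM wrapper adds the layer dimension via
    # torch.unsqueeze(state[i], 0), as seen in the traceback above.
    def get_initial_state(self):
        return [
            torch.zeros(self.lstm_hidden_size),
            torch.zeros(self.lstm_hidden_size),
        ]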