Related Questions: #343 and #9071
My model looks like this:
import ray
import numpy as np
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 lstm_state_size=128):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)

        # obs_space is 8 + 3 dims.
        obs = 8
        obs_r = 3
        self.lstm_state_size = lstm_state_size

        self.fc_obs = nn.Linear(obs, 64)
        self.fc_obs_r = nn.Linear(obs_r, 32)
        self.norm_obs = nn.LayerNorm(64)
        self.norm_obs_r = nn.LayerNorm(32)
        self.fc_cat = nn.Linear(64 + 32, 128)
        self.norm_cat = nn.LayerNorm(128)
        self.lstm = nn.LSTM(
            self.lstm_state_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """
        inputs (Tensor): observation tensor with shape [B, T, obs_size]
        state (list): list of state tensors, with sizes matching those
            returned by get_initial_state plus the batch dimension
        seq_lens (Tensor): 1D tensor holding input sequence lengths
        """
        obs = inputs[0][0][:8].float().unsqueeze(0).unsqueeze(0)
        obs_r = inputs[0][0][8:].float().unsqueeze(0).unsqueeze(0)
        obs = nn.functional.relu(self.norm_obs(self.fc_obs(obs)))
        obs_r = nn.functional.relu(self.norm_obs_r(self.fc_obs_r(obs_r)))
        cat = torch.cat([obs, obs_r], 2)
        cat = nn.functional.relu(self.norm_cat(self.fc_cat(cat)))
        self._features, [h, c] = self.lstm(
            cat, [torch.unsqueeze(state[0], 0),
                  torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(ModelV2)
    def get_initial_state(self):
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        # View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc_cat.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc_cat.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h
and the error is:
(pid=90398) 2021-07-12 19:52:45,784 ERROR worker.py:418 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=90398, ip=192.168.70.128)
(pid=90398) File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
(pid=90398) File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
(pid=90398) return method(__ray_actor, *args, **kwargs)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 537, in __init__
(pid=90398) policy_dict, policy_config)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1196, in _build_policy_map
(pid=90398) policy_map[name] = cls(obs_space, act_space, merged_conf)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/policy_template.py", line 281, in __init__
(pid=90398) stats_fn=stats_fn,
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/policy.py", line 623, in _initialize_loss_from_dummy_batch
(pid=90398) self._dummy_batch, explore=False)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 262, in compute_actions_from_input_dict
(pid=90398) seq_lens, explore, timestep)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
(pid=90398) return func(self, *a, **k)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 326, in _compute_action_helper
(pid=90398) seq_lens)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 234, in __call__
(pid=90398) res = self.forward(restored, state or [], seq_lens)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py", line 83, in forward
(pid=90398) output, new_state = self.forward_rnn(inputs, state, seq_lens)
(pid=90398) File "/home/zhaoyong/Codes/Python/RRL/ray_module/custom_model.py", line 61, in forward_rnn
(pid=90398) torch.unsqueeze(state[1], 0)])
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
(pid=90398) result = self.forward(*input, **kwargs)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 574, in forward
(pid=90398) self.check_forward_args(input, hx, batch_sizes)
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 529, in check_forward_args
(pid=90398) 'Expected hidden[0] size {}, got {}')
(pid=90398) File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 195, in check_hidden_size
(pid=90398) raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
(pid=90398) RuntimeError: Expected hidden[0] size (1, 1, 128), got (1, 32, 128)
(pid=90398)
Traceback (most recent call last):
File "/home/zhaoyong/Codes/Python/RRL/ray_module/ppo_agent.py", line 26, in <module>
agent = PPOTrainer(config, 'KheperaEnv-v0')
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 123, in __init__
Trainer.__init__(self, config, env, logger_creator)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 548, in __init__
super().__init__(config, logger_creator)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/tune/trainable.py", line 98, in __init__
self.setup(copy.deepcopy(self.config))
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 709, in setup
self._init(self.config, self.env_creator)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 155, in _init
num_workers=self.config["num_workers"])
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 797, in _make_workers
logdir=self.logdir)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/worker_set.py", line 83, in __init__
lambda p, pid: (pid, p.observation_space, p.action_space)))
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 62, in wrapper
return func(*args, **kwargs)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/worker.py", line 1497, in get
raise value
ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=90399, ip=192.168.70.128)
File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
return method(__ray_actor, *args, **kwargs)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 537, in __init__
policy_dict, policy_config)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1196, in _build_policy_map
policy_map[name] = cls(obs_space, act_space, merged_conf)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/policy_template.py", line 281, in __init__
stats_fn=stats_fn,
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/policy.py", line 623, in _initialize_loss_from_dummy_batch
self._dummy_batch, explore=False)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 262, in compute_actions_from_input_dict
seq_lens, explore, timestep)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
return func(self, *a, **k)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 326, in _compute_action_helper
seq_lens)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 234, in __call__
res = self.forward(restored, state or [], seq_lens)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py", line 83, in forward
output, new_state = self.forward_rnn(inputs, state, seq_lens)
File "/home/zhaoyong/Codes/Python/RRL/ray_module/custom_model.py", line 61, in forward_rnn
torch.unsqueeze(state[1], 0)])
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 574, in forward
self.check_forward_args(input, hx, batch_sizes)
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 529, in check_forward_args
'Expected hidden[0] size {}, got {}')
File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 195, in check_hidden_size
raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
RuntimeError: Expected hidden[0] size (1, 1, 128), got (1, 32, 128)
Process finished with exit code 1
As suggested in the previous advice, I have updated my Ray version to 1.4.1, but this problem with the state shapes is still there. How can I solve it?
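Reading the error message again (expected hidden size (1, 1, 128), got (1, 32, 128)), my own guess is that `inputs[0][0][:8]` keeps only the first timestep of the first sample, so after the two `unsqueeze` calls the LSTM input has batch size 1, while the incoming `state` tensors still carry the dummy batch size of 32. Is the right fix to process the full [B, T, 11] tensor and only unsqueeze the layer dimension of the hidden states? Below is an untested sketch of what I mean, reusing the layer names from my model above and assuming the shapes implied by the error:

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        # Assumed shapes: inputs is [B, T, 11]; state[0], state[1] are [B, 128].
        obs = inputs[..., :8].float()    # [B, T, 8]
        obs_r = inputs[..., 8:].float()  # [B, T, 3]
        obs = nn.functional.relu(self.norm_obs(self.fc_obs(obs)))
        obs_r = nn.functional.relu(self.norm_obs_r(self.fc_obs_r(obs_r)))
        cat = torch.cat([obs, obs_r], dim=2)                        # [B, T, 96]
        cat = nn.functional.relu(self.norm_cat(self.fc_cat(cat)))  # [B, T, 128]
        # nn.LSTM with batch_first=True expects hidden states shaped
        # [num_layers, B, hidden_size], so unsqueeze dim 0 (the layer dim)
        # and leave the batch dimension untouched.
        self._features, [h, c] = self.lstm(
            cat, [torch.unsqueeze(state[0], 0),
                  torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

With this version the hidden state passed to the LSTM would be (1, B, 128), which is what check_hidden_size seems to expect. If this is not the intended way to split a flat observation inside forward_rnn, what is?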