State shapes incorrect using custom model (TorchModelV2) (PPO)

Related Questions: #343 and #9071

My model looks like this:

import ray
import numpy as np

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

torch, nn = try_import_torch()


class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 lstm_state_size=128,
                ):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        # The observation is 11-dimensional: an 8-dim part plus a 3-dim part,
        # processed by separate branches below.
        obs = 8
        obs_r = 3
        self.lstm_state_size = lstm_state_size

        self.fc_obs = nn.Linear(obs, 64)
        self.fc_obs_r = nn.Linear(obs_r, 32)
        self.norm_obs = nn.LayerNorm(64)
        self.norm_obs_r = nn.LayerNorm(32)

        self.fc_cat = nn.Linear(64+32, 128)
        self.norm_cat = nn.LayerNorm(128)
        self.lstm = nn.LSTM(self.lstm_state_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """
        input_dict (dict) – dictionary of input tensors, including “obs”, “obs_flat”, “prev_action”, “prev_reward”, “is_training”, “eps_id”, “agent_id”, “infos”, and “t”.
        state (list) – list of state tensors with sizes matching those returned by get_initial_state + the batch dimension
        seq_lens (Tensor) – 1d tensor holding input sequence lengths
        """
        obs = inputs[0][0][:8].float().unsqueeze(0).unsqueeze(0)
        obs_r = inputs[0][0][8:].float().unsqueeze(0).unsqueeze(0)
        obs = nn.functional.relu(self.norm_obs(self.fc_obs(obs)))
        obs_r = nn.functional.relu(self.norm_obs_r(self.fc_obs_r(obs_r)))
        cat = torch.cat([obs, obs_r], 2)
        cat = nn.functional.relu(self.norm_cat(self.fc_cat(cat)))

        self._features, [h, c] = self.lstm(cat, [torch.unsqueeze(state[0], 0), 
                                                 torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(ModelV2)
    def get_initial_state(self):
        # TODO: (sven): Get rid of `get_initial_state` once Trajectory
        #  View API is supported across all of RLlib.
        # Place hidden states on same device as model.
        h = [
            self.fc_cat.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc_cat.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

and the error is:

Traceback (most recent call last):
  File "/home/zhaoyong/Codes/Python/RRL/ray_module/ppo_agent.py", line 26, in <module>
    agent = PPOTrainer(config, 'KheperaEnv-v0')
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 123, in __init__
    Trainer.__init__(self, config, env, logger_creator)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 548, in __init__
    super().__init__(config, logger_creator)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/tune/trainable.py", line 98, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 709, in setup
    self._init(self.config, self.env_creator)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 155, in _init
    num_workers=self.config["num_workers"])
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 797, in _make_workers
    logdir=self.logdir)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/worker_set.py", line 83, in __init__
    lambda p, pid: (pid, p.observation_space, p.action_space)))
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/_private/client_mode_hook.py", line 62, in wrapper
    return func(*args, **kwargs)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/worker.py", line 1497, in get
    raise value
ray.exceptions.RayActorError: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=90399, ip=192.168.70.128)
  File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 537, in __init__
    policy_dict, policy_config)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1196, in _build_policy_map
    policy_map[name] = cls(obs_space, act_space, merged_conf)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/policy_template.py", line 281, in __init__
    stats_fn=stats_fn,
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/policy.py", line 623, in _initialize_loss_from_dummy_batch
    self._dummy_batch, explore=False)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 262, in compute_actions_from_input_dict
    seq_lens, explore, timestep)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 326, in _compute_action_helper
    seq_lens)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/models/modelv2.py", line 234, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/ray/rllib/models/torch/recurrent_net.py", line 83, in forward
    output, new_state = self.forward_rnn(inputs, state, seq_lens)
  File "/home/zhaoyong/Codes/Python/RRL/ray_module/custom_model.py", line 61, in forward_rnn
    torch.unsqueeze(state[1], 0)])
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 574, in forward
    self.check_forward_args(input, hx, batch_sizes)
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 529, in check_forward_args
    'Expected hidden[0] size {}, got {}')
  File "/home/zhaoyong/.conda/envs/py37/lib/python3.7/site-packages/torch/nn/modules/rnn.py", line 195, in check_hidden_size
    raise RuntimeError(msg.format(expected_hidden_size, tuple(hx.size())))
RuntimeError: Expected hidden[0] size (1, 1, 128), got (1, 32, 128)

Process finished with exit code 1

As suggested in the previous advice, I have updated my Ray version to 1.4.1,
but there is still a problem with these state shapes. How can I solve it?
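
For context, the trainer is created with `agent = PPOTrainer(config, 'KheperaEnv-v0')` (visible in the traceback above). The full config isn't shown in this thread; a typical RLlib 1.4 wiring for a custom model looks roughly like the following sketch, with placeholder names and values, not the actual config:

from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.models import ModelCatalog

# Register the custom model under a name the config can reference.
ModelCatalog.register_custom_model("my_rnn", TorchRNNModel)

config = {
    "framework": "torch",
    "model": {
        "custom_model": "my_rnn",
        "max_seq_len": 20,  # placeholder; max RNN unroll length
    },
}
agent = PPOTrainer(config, 'KheperaEnv-v0')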

Hey @Glaucus-2G, thanks for filing this issue, but could you provide a full reproduction script?
I cannot debug this w/o knowing your env’s spaces and other information.

Thanks! :slight_smile:

Hey @sven1977, thanks for your reply. I think I have solved it.
The error occurred in this part of the code above:

def forward_rnn(self, inputs, state, seq_lens):
    obs = inputs[0][0][:8].float().unsqueeze(0).unsqueeze(0)
    obs_r = inputs[0][0][8:].float().unsqueeze(0).unsqueeze(0)

I have found that the inputs may have shape (1, 1, 128) or (1, 32, 128): the batch dimension varies (the size-32 case comes from RLlib's dummy-batch loss initialization, visible in the traceback), but I had treated them as always having the former shape (1, 1, 128).
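
A quick way to confirm this is to print the shapes at the top of forward_rnn (a throwaway debugging sketch, not part of the final model):

def forward_rnn(self, inputs, state, seq_lens):
    # The batch dimension of `inputs` is not always 1 (e.g. it is 32
    # during RLlib's dummy-batch loss initialization).
    print("inputs:", inputs.shape,
          "state[0]:", state[0].shape,
          "state[1]:", state[1].shape)
    ...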

I have modified it like this:

def forward_rnn(self, inputs, state, seq_lens):
    obs = inputs[:, :, :8]
    obs_r = inputs[:, :, 8:]

Now it runs well.
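
For completeness, a sketch of the full corrected forward_rnn with this change applied (same layers as the model above; only the slicing changed, so the batch and time dimensions are preserved throughout):

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: [B, T, 11]; slice only the feature dimension.
        obs = inputs[:, :, :8].float()
        obs_r = inputs[:, :, 8:].float()
        obs = nn.functional.relu(self.norm_obs(self.fc_obs(obs)))
        obs_r = nn.functional.relu(self.norm_obs_r(self.fc_obs_r(obs_r)))
        cat = torch.cat([obs, obs_r], dim=2)
        cat = nn.functional.relu(self.norm_cat(self.fc_cat(cat)))

        # State tensors come in as [B, state_size]; nn.LSTM expects
        # [num_layers, B, state_size], hence the unsqueeze/squeeze pair.
        self._features, [h, c] = self.lstm(
            cat, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

With the batch dimension preserved, the LSTM's hidden-state check passes for any batch size, and value_function() returns one value per (batch, time) entry, as RLlib expects.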