[RLlib] Shape Error for custom PyTorch model

Hi, I am trying to run PPO with a custom model in PyTorch. The observations are of shape (72, 96, 16) and the forward function generates vectors of length 256 for each observation. According to the documentation the forward function should return: The model output tensor of size [BATCH, num_outputs] and rnn state

I am getting the following error: ValueError: Expected output shape of [None, None], got torch.Size([32, 256])

What am I missing?
Thanks in advance.

import gym, ray
from ray.tune.registry import register_env
import numpy as np
import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

from ray.rllib.utils.framework import try_import_tf, try_import_torch
torch, nn = try_import_torch() 

class DummyGFEnv(gym.Env):
    def __init__(self):
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(72, 96, 16), dtype=np.uint8)
        self.action_space = gym.spaces.discrete.Discrete(19)

    def reset(self):
        return np.zeros((72,96,16), dtype=np.uint8)

    def step(self, action):
        return np.zeros((72,96,16), dtype=np.uint8), 0, False, {}

def dummy_env_creator(something):
    return DummyGFEnv()

ray.init()


class GFImpalaTorch(TorchModelV2, nn.Module):
    """Custom model for policy gradient algorithms."""

    def create_basic_res_block(self, in_channel, out_channel):
        return nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channel, out_channels=out_channel, kernel_size=3, stride=1, padding=1),
        )

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name):
        nn.Module.__init__(self)
        super(GFImpalaTorch, self).__init__(obs_space, action_space, None,
                                                 model_config, name)

        self.conv_layers_config = [(16, 2), (32, 2), (32, 2), (32, 2)]
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2)

        n_input_channels = 16

        self.conv_blocks = [
            nn.Conv2d(in_channels=n_input_channels, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)
        ]

        #https://www.tensorflow.org/api_docs/python/tf/nn/pool  -> If padding = "SAME": output_spatial_shape[i] = ceil(    input_spatial_shape[i] / strides[i])
        self.pools = [nn.MaxPool2d(kernel_size=3, stride=2, padding=1) for _ in range(4)]

        self.resblocks_1 = [
            self.create_basic_res_block(16, 16),
            self.create_basic_res_block(32, 32),
            self.create_basic_res_block(32, 32),
            self.create_basic_res_block(32, 32)
        ]
        self.resblocks_2 = [
            self.create_basic_res_block(16, 16),
            self.create_basic_res_block(32, 32),
            self.create_basic_res_block(32, 32),
            self.create_basic_res_block(32, 32)
        ]

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.linear = nn.Sequential(nn.Linear(960, 256), nn.ReLU()) #n_flatten=960, features_dim = 256



    def forward(self, input_dict, state, seq_lens):

        observations = input_dict["obs"]
        observations = observations.float()
        observations = observations.permute(0,3,1,2)
        observations /= 255.0


        conv_out = observations
        for i in range(4):
            conv_out = self.conv_blocks[i](conv_out)
            conv_out = self.pools[i](conv_out)

            block_input = conv_out
            conv_out = self.resblocks_1[i](conv_out)
            conv_out += block_input

            block_input = conv_out
            conv_out = self.resblocks_2[i](conv_out)
            conv_out += block_input

        conv_out = self.relu(conv_out)
        conv_out = self.flatten(conv_out)
        conv_out = self.linear(conv_out)


        self._output = conv_out
        return conv_out, state 

    
    def value_function(self):
        assert self._output is not None, "must call forward first!"
        value_out = torch.reshape(self._output, [-1])
        return value_out


register_env("my_env", dummy_env_creator)
ModelCatalog.register_custom_model("gf_impala_cnn_ppo_torch", GFImpalaTorch)


from ray.rllib.agents.ppo import PPOTrainer
tune.run(PPOTrainer, config={
    "env": "my_env", 
    "num_workers": 2,
    'model': {
        "custom_model": "gf_impala_cnn_ppo_torch"
        },
    "framework": "torch",
})

The complete error message:

Failure # 1 (occurred at 2021-03-11_15-51-44)
Traceback (most recent call last):
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/tune/trial_runner.py", line 586, in _process_trial
    results = self.trial_executor.fetch_result(trial)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/tune/ray_trial_executor.py", line 609, in     fetch_result
    result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/worker.py", line 1456, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): e[36mray::PPO.train_buffered()e[39m (pid=3755, ip=192.168.0.12)
  File "python/ray/_raylet.pyx", line 439, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 473, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 476, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/agents/trainer_template.py", line 107, in     __init__
    Trainer.__init__(self, config, env, logger_creator)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 486, in __init__
    super().__init__(config, logger_creator)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/tune/trainable.py", line 97, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 654, in setup
    self._init(self.config, self.env_creator)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/agents/trainer_template.py", line 134, in _init
    self.workers = self._make_workers(
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 725, in _make_workers
    return WorkerSet(
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/evaluation/worker_set.py", line 79, in __init__
    remote_spaces = ray.get(self.remote_workers(
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 47, in wrapper
    return func(*args, **kwargs)
ray.exceptions.RayTaskError(ValueError): e[36mray::RolloutWorker.foreach_policy()e[39m (pid=3756, ip=192.168.0.12)
  File "python/ray/_raylet.pyx", line 439, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 473, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 476, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 480, in ray._raylet.execute_task
  File "python/ray/_raylet.pyx", line 432, in ray._raylet.execute_task.function_executor
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 479, in     __init__
    self.policy_map, self.preprocessors = self._build_policy_map(
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1111, in     _build_policy_map
    policy_map[name] = cls(obs_space, act_space, merged_conf)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 266, in     __init__
    self._initialize_loss_from_dummy_batch(
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/policy/policy.py", line 622, in     _initialize_loss_from_dummy_batch
    self.compute_actions_from_input_dict(input_dict, explore=False)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 207, in     compute_actions_from_input_dict
    return self._compute_action_helper(input_dict, state_batches,
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/utils/threading.py", line 21, in wrapper
    return func(self, *a, **k)
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 249, in     _compute_action_helper
    dist_inputs, state_out = self.model(input_dict, state_batches,
  File "/Users/azadsalam/anaconda3/envs/rllib_py38/lib/python3.8/site-packages/ray/rllib/models/modelv2.py", line 225, in __call__
    raise ValueError(
ValueError: Expected output shape of [None, None], got torch.Size([32, 256])

@sven1977 could you help take a look here?