Custom RNN Model with Examples - why do they fail?

Using the current pip install of ray[debug] and ray[rllib], here is my minimal reproducible example #1, which uses dictionary observations.

This fails with the following error:
RuntimeError: Expected hidden[0] size (1, 1, 256), got [1, 32, 256]

import sys
import numpy as np
import gym
from gym import spaces
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from ray import tune
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
torch, nn = try_import_torch()

class TestingGym(gym.Env):
    metadata = {'render.modes': ['human']}
    def __init__(self, timesteps=5):
        self.timesteps = timesteps
        super(TestingGym, self).__init__()
        self.reward_range = (-1000, 1000)
        self.action_space = spaces.Box( low=np.array([0, 0]), high=np.array([4, 1]) )
        self.done_counter = 0
        self.input_1_shape = (16,)
        self.input_2_shape = (16,)

        self.observation_space = spaces.Dict(
            dict(
              input_1=spaces.Box(low=-np.inf, high=np.inf, shape=self.input_1_shape, dtype=np.float32),
              input_2=spaces.Box(low=-np.inf, high=np.inf, shape=self.input_2_shape, dtype=np.float32)
            )
        )
    def get_observation(self):
        curr_obs = dict(input_1=np.random.random(self.input_1_shape),
                        input_2=np.random.random(self.input_2_shape))
        return curr_obs

    def step(self, action):
        self.done_counter += 1
        if self.done_counter > 1000:
            done = True
        else:
            done = False
        return self.get_observation(), 1, done, {}

    def reset(self):
        self.done_counter = 0
        return self.get_observation()


def env_creator(env_config):
    env = TestingGym()
    return env

class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=32,
                 lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.obs_space = obs_space
        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        self.input_1_fc = nn.Linear(16, 16)
        self.input_2_fc = nn.Linear(16, 16)

        self.fc1 = nn.Linear(32, self.fc_size)
        self.lstm = nn.LSTM( self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h
    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        original_obs = restore_original_dimensions(
            torch.squeeze(inputs,1) , self.obs_space, "torch")

        x1 = nn.functional.relu( self.input_1_fc( original_obs['input_1'] ) )
        x2 = nn.functional.relu( self.input_2_fc( original_obs['input_2'] ) )
        # Join the outputs
        x = torch.cat((x2, x1), dim=1)

        x = nn.functional.relu(self.fc1(x))

        x = torch.unsqueeze( x, 0)
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]


import ray
import ray.rllib.agents.a3c as a3c
from ray.tune.logger import pretty_print
from IPython.display import clear_output
import copy


ray.shutdown(); ray.init()

ModelCatalog.register_custom_model("torch_rnn_model", TorchRNNModel)
tune.registry.register_env(u"TestingGym", env_creator)

trainer = a3c.A2CTrainer(
        env = "TestingGym",
        config={
            "num_workers": 1,
            "lr": 0.000001,
            "framework": "torch",
            "model": { "custom_model": "torch_rnn_model" }        
          },
        )
for i in range(10):
    result = trainer.train()
    clear_output()
    print(pretty_print(result))

And another attempt, this time using TorchModelV2 with use_lstm or use_attention… both fail. use_lstm fails with:

RuntimeError: input.size(-1) must be equal to input_size. Expected 32, got 16

It (forward_rnn) seems to expect the original, flattened observation size after forward. This is the biggest issue, because we want to process our observation first and then pass the result to the LSTM. Are we missing something?
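Just to confirm what that message literally means, here is a plain-PyTorch reproduction (no RLlib involved; the numbers are chosen to match the error):

import torch
import torch.nn as nn

# An LSTM built with input_size=32 (our flattened obs is also 32 = 16 + 16)...
lstm = nn.LSTM(input_size=32, hidden_size=256, batch_first=True)
# ...rejects the 16-dim features that our forward() returns:
features = torch.zeros(4, 8, 16)   # [batch, time, feature]
lstm(features)
# RuntimeError: input.size(-1) must be equal to input_size. Expected 32, got 16

Here is the full attempt: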

from IPython.display import clear_output
import torch
import torch.nn as nn

import ray
from ray import tune
from ray.tune.logger import pretty_print
from ray.rllib.agents import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override

class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.input_1_fc = nn.Linear(16, 16)
        self.input_2_fc = nn.Linear(16, 16)

        self.fc1 = nn.Linear(32, 16)
        num_outputs=2
        # self.action_branch = nn.Linear(32, num_outputs)
        self.value_branch = nn.Linear(32, 1)
        # self._logits = ...
        self._features = None
              
    def forward(self, input_dict, state, seq_lens):
        x1 = nn.functional.relu( self.input_1_fc( input_dict['obs']['input_1'] ) )
        x2 = nn.functional.relu( self.input_2_fc( input_dict['obs']['input_2'] ) )
        x = torch.cat((x2, x1), dim=1)
        self._features = nn.functional.relu(self.fc1(x))
        # action_out = self.action_branch(self._features)
        return self._features, state

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])


ray.shutdown(); ray.init()

ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)
tune.registry.register_env(u"TestingGym", env_creator)

trainer = ppo.PPOTrainer(env="TestingGym",  
  config={
    "framework": "torch",
    "model": {
        "use_lstm":True,
        "custom_model": "my_torch_model",
    },
})
for i in range(10):
    result = trainer.train()
    clear_output()
    print(pretty_print(result))

Please correct me, please be critical, and tell me the obvious things I’m missing and any other mistakes you see… like passing self._features, or commenting out the action_branch. Side question: aren’t self._logits and self._features the same thing? I understand logits to be the tensor just before the last output layer; perhaps self._features is the designated output of the hidden layers.

I’m pretty sure the RNN inputs come in the shape [Batch, Time, Feature]. You want outputs in this format as well.
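A quick plain-PyTorch check of that convention, in case it helps: with batch_first=True, an nn.LSTM takes [Batch, Time, Feature] and returns outputs with the same layout.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=256, batch_first=True)
x = torch.zeros(4, 8, 32)        # [batch=4, time=8, feature=32]
out, (h, c) = lstm(x)
print(out.shape)                 # torch.Size([4, 8, 256]) -> [batch, time, hidden]
print(h.shape)                   # torch.Size([1, 4, 256]) -> [num_layers, batch, hidden]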

Thanks, you’re correct, and that’s how they’re coming into forward_rnn. This seems to be an old issue with initializing the hidden state of the LSTM inside the get_initial_state definition. The forward and forward_rnn passes run 11 times, and then on the 12th we get the first error:

Expected hidden[0] size (1, 1, 256), got [1, 32, 256]
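To make sure I’m reading that message correctly, here is the same error reproduced in plain PyTorch (no RLlib): the hidden state’s batch dimension has to match the input’s batch dimension.

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=256, batch_first=True)
x = torch.zeros(1, 32, 32)        # batch=1, time=32, feature=32
h0 = torch.zeros(1, 32, 256)      # but the hidden state carries batch=32
c0 = torch.zeros(1, 32, 256)
lstm(x, (h0, c0))
# RuntimeError: Expected hidden[0] size (1, 1, 256), got [1, 32, 256]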

I’m also looking at the first line inside forward_rnn, the restore_original_dimensions call. I changed it to the following, which yielded the same results:

original_obs = restore_original_dimensions( inputs , self.obs_space, "torch")        
original_obs['input_1'] = torch.squeeze(original_obs['input_1'], 2)
original_obs['input_2'] = torch.squeeze(original_obs['input_2'], 2)

If we omit the last two lines and don’t remove a dimension from the observation, it also runs 12 times and then yields:

RuntimeError: Expected 4-dimensional input for 4-dimensional weight [28, 1, 28, 28], 
but got 5-dimensional input of size [32, 1, 1, 28, 28] instead

@sven1977 you’ve worked on PRs related to this before, your thoughts?

Hi @Gregory,

The main issue is in two lines:

    original_obs = restore_original_dimensions(
            torch.squeeze(inputs,1) , self.obs_space, "torch")
    ...
    x = torch.cat((x2, x1), dim=1)

In the case where it fails, the shape of inputs is torch.Size([4, 8, 32]).

torch.squeeze(inputs, 1) has no effect here (dim 1 has size 8, not 1), so the shapes of input_1 and input_2 in original_obs are both torch.Size([4, 8, 16]).

When you cat along dim 1 you end up with torch.Size([4, 16, 16]); you have concatenated along the time dimension instead of the feature dimension.

What you really want to do is: x = torch.cat((x2, x1), dim=-1)
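You can verify those shapes in isolation with plain torch (numbers copied from the failing case):

import torch

inputs = torch.zeros(4, 8, 32)             # [batch, time, flattened obs]
print(torch.squeeze(inputs, 1).shape)      # torch.Size([4, 8, 32]) -> no-op, dim 1 is 8, not 1
x1 = torch.zeros(4, 8, 16)                 # input_1 after restore_original_dimensions
x2 = torch.zeros(4, 8, 16)                 # input_2 after restore_original_dimensions
print(torch.cat((x2, x1), dim=1).shape)    # torch.Size([4, 16, 16]) -> concatenated along time (wrong)
print(torch.cat((x2, x1), dim=-1).shape)   # torch.Size([4, 8, 32])  -> concatenated along features (right)

With that change, forward_rnn becomes: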

@override(TorchRNN)
def forward_rnn(self, inputs, state, seq_lens):
    original_obs = restore_original_dimensions(
        inputs , self.obs_space, "torch")

    x1 = nn.functional.relu( self.input_1_fc( original_obs['input_1'] ) )
    x2 = nn.functional.relu( self.input_2_fc( original_obs['input_2'] ) )
    # Join the outputs
    x = torch.cat((x2, x1), dim=-1)

    x = nn.functional.relu(self.fc1(x))

    self._features, [h, c] = self.lstm(
        x, [torch.unsqueeze(state[0], 0),
            torch.unsqueeze(state[1], 0)])
    action_out = self.action_branch(self._features)
    return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

@mannyv you’re my savior, thank you for taking some of your time to share the solution with me. :pray:t4: :pray:t4: :pray:t4: Hopefully this thread and your answer also help others in the future who might find themselves in a similar situation.

A follow-up to the same example, this time using a convolutional layer. It’s the same as before, with your corrections, except that input_1 is now an image. That input is squeezed because the convolutional layer expects 4 dimensions (including batch), not 5 with the added time dimension for the RNN. Here’s the minimal example, which ends with this error:

RuntimeError: Expected hidden[0] size (1, 1, 256), got [1, 32, 256]

edit: I’m finding that squeeze is likely not appropriate here, and I’m studying what to do in this case.

import sys
import numpy as np
import copy
from ray.rllib.models import ModelCatalog
from ray.tune.registry import register_env
from ray import tune
import gym
from gym import spaces
import ray
import ray.rllib.agents.a3c as a3c
from IPython.display import clear_output
from ray.tune.logger import pretty_print
from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.recurrent_net import RecurrentNetwork
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf, try_import_torch
torch, nn = try_import_torch()

class TestingGym(gym.Env):
    metadata = {'render.modes': ['human']}

    def __init__(self, timesteps=5):
        self.timesteps = timesteps
        super(TestingGym, self).__init__()
        self.reward_range = (-sys.float_info.max-1, sys.float_info.max)
        self.action_space = spaces.Box(low=np.array([0, 0]), high=np.array([4, 1]), dtype=np.float16)
        self.done_counter = 0
    
        self.input_1_shape = (1, 28, 28)
        self.input_2_shape = (1, 16)

        self.observation_space = spaces.Dict(
            dict(
              input_1=spaces.Box(low=-sys.float_info.max-1, high=sys.float_info.max, shape=self.input_1_shape, dtype=np.float32),
              input_2=spaces.Box(low=-sys.float_info.max-1, high=sys.float_info.max, shape=self.input_2_shape, dtype=np.float32)
            )
        )
    def get_observation(self):
        curr_obs = dict( input_1 = np.random.random( self.input_1_shape ),
                         input_2 = np.random.random( self.input_2_shape ) 
                         )
        return curr_obs

    def step(self, action):
        self.done_counter += 1
        if self.done_counter > 1000:
            done = True
        else:
            done = False
        return self.get_observation(), 1, done, {}

    def reset(self):
        self.done_counter = 0
        return self.get_observation()

def env_creator(env_config):
    env = TestingGym()
    return env

class TorchRNNModel(TorchRNN, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 fc_size=44,
                 lstm_state_size=256):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name)
        self.obs_space = obs_space
        self.obs_size = get_preprocessor(obs_space)(obs_space).size
        self.fc_size = fc_size
        self.lstm_state_size = lstm_state_size

        self.input_1_conv2d = nn.Conv2d(1, 28, 28, stride=1)
        self.input_2_fc = nn.Linear(16, 16)

        self.fc1 = nn.Linear(44, self.fc_size)
        self.lstm = nn.LSTM(
            self.fc_size, self.lstm_state_size, batch_first=True)
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        h = [
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0),
            self.fc1.weight.new(1, self.lstm_state_size).zero_().squeeze(0)
        ]
        return h

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        original_obs = restore_original_dimensions( inputs , self.obs_space, "torch")        

        print(f"inputs shape {inputs.shape}")
        print(f"input_1 shape: {original_obs['input_1'].shape}")
        print(f"input_2 shape: {original_obs['input_2'].shape}")

        original_obs['input_1'] = torch.squeeze(original_obs['input_1'],1)
        original_obs['input_2'] = torch.squeeze(original_obs['input_2'],1)
        
        x = self.input_1_conv2d( original_obs['input_1'] )
        print(f"input_1 out shape {x.shape}")

        x = x.contiguous().view(-1, 28) # x.shape[0]
        print(f"input_1 out flat shape {x.shape}")
        

        x2 = nn.functional.relu( self.input_2_fc( original_obs['input_2'] ) )
        print(f"input_2 out shape {x2.shape}")
        
        x2 = x2.contiguous().view(-1, 16) # x2.shape[0]
        print(f"input_2 out flat shape {x2.shape}")
        

        # Join the outputs
        x = torch.cat((x2, x), dim=-1)

        x = nn.functional.relu(self.fc1(x))

        x = torch.unsqueeze( x, 0)
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0),
                torch.unsqueeze(state[1], 0)])
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

ray.shutdown(); ray.init()

ModelCatalog.register_custom_model("torch_rnn_model", TorchRNNModel)
tune.registry.register_env(u"TestingGym", env_creator)

trainer = a3c.A2CTrainer(
        env = "TestingGym",
        config={
            "num_workers": 1,
            "lr": 0.000001,
            "framework": "torch",
            "model": {
                "custom_model": "torch_rnn_model",
            } }, )
            
for i in range(10):
    result = trainer.train()
    clear_output()
    print(pretty_print(result))

The other model, using TorchModelV2, is also repaired by changing torch.cat to use dim=-1 and adjusting the layer sizes. Here’s the working copy in case it helps anyone down the road.

class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        self.input_1_fc = nn.Linear(16, 16)
        self.input_2_fc = nn.Linear(16, 16)

        self.fc1 = nn.Linear(32, 32)
        self.value_branch = nn.Linear(32, 1)
        self._features = None
              
    def forward(self, input_dict, state, seq_lens):
        x1 = nn.functional.relu( self.input_1_fc( input_dict['obs']['input_1'] ) )
        x2 = nn.functional.relu( self.input_2_fc( input_dict['obs']['input_2'] ) )
        x = torch.cat((x2, x1), dim=-1)
        self._features = nn.functional.relu(self.fc1(x))
        return self._features, state

    @override(ModelV2)
    def value_function(self):
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

I was going to give you this answer last time, but when you were only using linear layers it was easy enough to do everything in forward_rnn, since a linear layer is applied along the last dimension automatically. When you are using a conv layer, though, you really want to operate on the input before it has been reshaped to include the time dimension.

The way to do this is to override the forward method as well as forward_rnn. The only real expectation of the default TorchRNN.forward that you need to observe is that the input arrives in a dictionary under the key "obs_flat". So the strategy is: implement forward, apply whatever processing you want, then pass the result to the superclass forward method.

That said, since you are not doing anything non-standard with the LSTM or its output, it would be better to implement this as a TorchModelV2 and set "use_lstm" in the config to use the default LSTM wrapper (see the sketch after the code below).

I am not sure why you made input_2’s shape (1, 16). I changed it to input_2_shape = (16,) to drop that singleton dimension. If you want to keep it as (1, 16), I left the squeeze you would need as a comment.

@override(TorchRNN)
def forward(self, input_dict, state, seq_lens):
    original_obs = input_dict["obs"]
    print(f"input_1 shape: {original_obs['input_1'].shape}")
    print(f"input_2 shape: {original_obs['input_2'].shape}")

    x = self.input_1_conv2d( original_obs['input_1'] )
    print(f"input_1 out shape {x.shape}")

    x = x.flatten(-3,-1)
    print(f"input_1 out squeeze shape {x.shape}")

    x2 = nn.functional.relu( self.input_2_fc( original_obs['input_2'] ) )
    print(f"input_2 out shape {x2.shape}")
    # x2 = x2.squeeze(1)
    # print(f"input_2 out squeeze shape {x2.shape}")

    # Join the outputs
    x = torch.cat((x2, x), dim=-1)
    x = nn.functional.relu(self.fc1(x))
    return super().forward({"obs_flat": x}, state, seq_lens)

@override(TorchRNN)
def forward_rnn(self, inputs, state, seq_lens):
    print(f"forward_rnn inputs shape {inputs.shape}")
    self._features, [h, c] = self.lstm(inputs,
                                       [torch.unsqueeze(state[0], 0),
                                        torch.unsqueeze(state[1], 0)])
    action_out = self.action_branch(self._features)
    return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]
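For completeness, here is a rough, untested sketch of the TorchModelV2 + use_lstm alternative I mentioned (the class name ConvFeatureModel is mine). Treat the wrapper details as assumptions to verify against your RLlib version; in particular, I am assuming the LSTM wrapper uses the wrapped model's self.num_outputs as its LSTM input size.

import torch
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

class ConvFeatureModel(TorchModelV2, nn.Module):
    """Sketch only: a per-timestep feature extractor. The recurrence, logits
    and value branch come from the built-in wrapper when "use_lstm": True."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        self.input_1_conv2d = nn.Conv2d(1, 28, 28, stride=1)
        self.input_2_fc = nn.Linear(16, 16)
        self.fc1 = nn.Linear(44, 44)   # 28 conv features + 16 fc features
        # Assumption: the LSTM wrapper sizes its LSTM input from this value.
        self.num_outputs = 44

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        x1 = self.input_1_conv2d(obs["input_1"]).flatten(1)        # [B, 28]
        x2 = nn.functional.relu(self.input_2_fc(obs["input_2"]))   # [B, 16]
        features = nn.functional.relu(self.fc1(torch.cat((x2, x1), dim=-1)))
        return features, state

# Registered and configured roughly like:
# ModelCatalog.register_custom_model("conv_feature_model", ConvFeatureModel)
# config["model"] = {"custom_model": "conv_feature_model", "use_lstm": True}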

Thanks, Manny, for your generous help. Your response is pure gold to me. I had managed to get it running inside forward_rnn, based on this post, by using view to reshape, multiplying the batch and time dimensions together and then restoring those dimensions after the convolution layer… all the while thinking there had to be a more efficient way!

What you suggest, overriding the forward function and calling the superclass forward method, is the correct solution and something that eluded me for so long as a self-taught coder. As for input_2’s (1, 16) shape… it was just a random choice made while simplifying the example. My most sincere compliments to you and the skill level you’ve attained. Thank you again for sharing your valuable experience, my friend :pray:t4: :tophat:
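For anyone curious, the view-based juggling I mentioned above looked roughly like this (shown in isolation with plain PyTorch; not the exact code I ran):

import torch
import torch.nn as nn

conv = nn.Conv2d(1, 28, 28, stride=1)
obs = torch.zeros(4, 8, 1, 28, 28)        # [batch, time, channels, H, W]
B, T = obs.shape[0], obs.shape[1]
flat = obs.view(B * T, *obs.shape[2:])    # [32, 1, 28, 28] -> the 4-D input Conv2d expects
out = conv(flat)                          # [32, 28, 1, 1]
out = out.view(B, T, -1)                  # [4, 8, 28] -> back to [batch, time, features]
print(out.shape)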


Thanks @mannyv and @Gregory for the question and solution!
Just a heads-up: we are currently thinking about simplifying RLlib’s Model API quite a bit. We would allow users to specify native Keras or Torch models (no more ModelV2), get rid of our default preprocessors (your model would receive observations exactly as they come from the env; no more opaque flattening operations), remove the need for a forward_rnn, etc.
We are hoping this will alleviate many of the issues users currently have, especially with RNNs and with complex model architectures and input spaces.


@Gregory, thanks for providing a working copy. I am trying to use a similar approach, setting use_lstm to True. Additionally, I am trying to enable action masking. Do you know which function I should override to achieve that? I’m a bit confused by the documentation on this matter; setting the flag to True together with action masking gives an error. Thanks!