Problem with handling states in RNN

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.
    Hi everyone,
    I’ve written an environment in which the observation is a NumPy array with a shape of [3,200,4]. In the forward method, I unpack the observation into three tensors of size [batch,200,4]. These observations go through two LSTMs with regimes, and then I have a few FNNs. I’ve tried to handle states by returning them with a shape of [batch,] in the get_initial_state method. I don’t think I should use the Rllib recurrent network class, so I used only TorchModelV2 . For the first two forward passes, everything is fine, but on the third pass, I get some errors about the shape of my states, and I don’t know what to do because as far as I can see, I didn’t do anything wrong with the shapes of my states. Here is the error:
    Expected hidden[0] size (1, 32, 64), got [1, 4, 64]
    and code for reproducing error:
import gymnasium as gym
from gymnasium import spaces
class dummy_env(gym.Env):
  def __init__(self,env_config):
    self.action_space = spaces.Tuple((
    spaces.Box(low=0.8, high=1.2, shape=(1,), dtype=float),
    spaces.Box(low=0, high=0.1, shape=(1,), dtype=float),
    spaces.Box(low=0, high=0.1, shape=(1,), dtype=float)
    self.observation_space = spaces.Box(low=-10, high=10,\
  def reset(self):
    self.done = False
    self.time = 0
    observation = self.get_observations()
    return observation    

  def step(self,action):
    self.time +=1
    a1,a2,a3,a4 = action
    reward = 0
    observation = self.get_observations()
    if self.time == 10:
      self.done = True
    info = {}  
    return observation, reward, self.done, info

  def get_observations(self):
    observation = np.random.uniform(low=1, high=2, size=(3, 200, 4))
    return observation
env_config = {}
import numpy as np
import torch
import torch.nn as nn
import gymnasium as gym
from gymnasium.spaces import Discrete, MultiDiscrete
import tree  # pip install dm_tree
from ray.rllib.utils.annotations import override, DeveloperAPI
from typing import Dict, List, Union, Tuple
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.models.torch.misc import SlimFC
class LSTMModel(TorchModelV2,nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
      nn.Module.__init__(self) # Initialize nn.Module before calling super
      super(LSTMModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)

      self.action_space_struct = get_base_struct_from_space(self.action_space)
      self.action_dim = 0

      self.hidden_size = 64
      self.lstm1 = nn.LSTM(input_size=obs_space.shape[2],
      self.lstm2 = nn.LSTM(input_size=64,
      self.fnn = SlimFC(
      in_size=int(self.hidden_size * 3 *0.5),
      self._logits_branch = SlimFC(
      in_size=int(self.hidden_size * 3 *0.5 + 2),
      self._value_branch = SlimFC(
      in_size=int(self.hidden_size * 3 *0.5 + 2),

    def forward(self,input_dict: Dict[str, TensorType],state: List[TensorType],seq_lens: TensorType): 
        '''please document well'''
        # cell_state1,cell_state11,cell_state2,cell_state22,cell_state3,cell_state33 = state
        obs = input_dict["obs"]
        obs1, obs2, obs3 = torch.split(obs, split_size_or_sections=1, dim=1)      
        obs1 = obs1.squeeze(dim=1)
        obs2 = obs2.squeeze(dim=1)
        obs3 = obs3.squeeze(dim=1)
        # print(obs.shape)
        # print(obs1.shape,"input dict obs1") 
        # print(state,"state") 
        # print('state',torch.unsqueeze(state[0], 0).shape)
        lstm_out1, [h1, c1] = self.lstm1(obs1, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)])
        lstm_out11, [h11, c11] = self.lstm2(lstm_out1, [torch.unsqueeze(state[2], 0), torch.unsqueeze(state[3], 0)])
        lstm_out2, [h2, c2] = self.lstm1(obs2, [torch.unsqueeze(state[4], 0), torch.unsqueeze(state[5], 0)])
        lstm_out22, [h22, c22] = self.lstm2(lstm_out2, [torch.unsqueeze(state[6], 0), torch.unsqueeze(state[7], 0)])
        lstm_out3, [h3, c3] = self.lstm1(obs3, [torch.unsqueeze(state[8], 0), torch.unsqueeze(state[9], 0)])
        lstm_out33, [h33, c33] = self.lstm2(lstm_out3, [torch.unsqueeze(state[10], 0), torch.unsqueeze(state[11], 0)])
        # print("outs",lstm_out1.shape,lstm_out11.shape,lstm_out2.shape,lstm_out22.shape)
        lstm_out11 = lstm_out11[:, -1, :]
        lstm_out22 = lstm_out22[:, -1, :]
        lstm_out33 = lstm_out33[:, -1, :]
        #TODO: there are many state handling
        context =[lstm_out11, lstm_out22, lstm_out33], dim=1)
        fnn_output = self.fnn(context)
        a, b = torch.split(fnn_output, 1, dim=1)
        self.concatinated_context =[a, b, context], dim=1)
        output = self._logits_branch(self.concatinated_context)
        # print('hiddenstate', h11.shape)
        # print("h saved",h11.squeeze().shape)
        return output , [h1.squeeze(), c1.squeeze(),h11.squeeze(), c11.squeeze(),
                         h2.squeeze(), c2.squeeze(),h22.squeeze(), c22.squeeze(),
                         h3.squeeze(), c3.squeeze(),h33.squeeze(), c33.squeeze()]

    def get_initial_state(self):
        # Return a list of initial states for each recurrent layer.
        # In this case, there are two LSTM layers, so the list contains two elements.
        fnn = next(self.fnn._model.children())
        h = [,  self.hidden_size).zero_().squeeze(0),,  self.hidden_size).zero_().squeeze(0),,int(0.5 * self.hidden_size)).zero_().squeeze(0),,int(0.5 * self.hidden_size)).zero_().squeeze(0),,  self.hidden_size).zero_().squeeze(0),,  self.hidden_size).zero_().squeeze(0),,int(0.5 * self.hidden_size)).zero_().squeeze(0),,int(0.5 * self.hidden_size)).zero_().squeeze(0),,  self.hidden_size).zero_().squeeze(0),,  self.hidden_size).zero_().squeeze(0),,int(0.5 * self.hidden_size)).zero_().squeeze(0),,int(0.5 * self.hidden_size)).zero_().squeeze(0)

        return h

    def value_function(self) -> TensorType:
        assert self.concatinated_context is not None, "must call forward() first"
        return(torch.reshape(self._value_branch(self.concatinated_context), [-1]))

from ray.rllib.algorithms import ppo
from ray.tune.registry import register_env
def env_creator(env_config):
    return dummy_env(env_config)
model_config = ppo.DEFAULT_CONFIG.copy()    
register_env("dummy_env", env_creator)
model_config["env"] = dummy_env
model_config["env_config"] = env_config
model_config['batch_mode'] = 'truncated_episodes'
model_config['num_gpus'] = 1
model_config['num_gpus_per_worker'] = 0
model_config['num_cpus'] = 1
model_config['num_workers'] = 1
model_config['num_sgd_iter'] = 6
model_config['train_batch_size'] = 1000

model_config["model"]= {'custom_model' : 'LSTMModel'}
import ray
from ray.rllib.models import ModelCatalog
import threading
from tensorboard import notebook
from ray.rllib.algorithms import ppo
from ray import tune
ModelCatalog.register_custom_model('LSTMModel', LSTMModel)
algo = ppo.PPO(config=model_config)
for i in range(10):
    results = algo.train()
    print("Iteration {}, episode_reward_mean = {}".format(
        i, results["episode_reward_mean"]))

> Expected hidden[0] size (1, 32, 64), got [1, 4, 64]

Hi @hossein836,

Try reading this post and see if it helps.

thanks @mannyv for your response as always. :slightly_smiling_face:
I’ve read what you’ve mentioned before and thats a good explanation.
but I think there is a difference in how I used LSTM and how rllib used LSTM. in my model, LSTM is somehow an encoder because in every timestep the model gets a 3D observation(no need to add time dimension).
so it doesn’t matter how many timesteps is in my batch sequence, my real seq_len is always 1 because I don’t need previous observations to train my model. all that I need is just hidden states to be passed to next timestep.
based on my explanations I don’t know what’s going wrong since from what I’ve read, states are being updated every time we call Forward method and will be passed as states in next pass so I thought I have nothing to do with model.view_requirements.
I will appreciate if you can guide me, also code is runnable :wink: