How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Hi everyone,

I've written an environment whose observation is a NumPy array of shape [3, 200, 4]. In the model's `forward` method I split the observation into three tensors of shape [batch, 200, 4], run each of them through the same two stacked LSTMs, and then through a few feed-forward layers. I handle the recurrent states by returning tensors of shape [hidden_size] from the `get_initial_state` method. I don't think I need RLlib's recurrent network class here, so I built the model on `TorchModelV2` alone. The first two forward passes work fine, but on the third pass I get an error about the shape of my states, and as far as I can tell I haven't done anything wrong with those shapes. Here is the error:
```
Expected hidden[0] size (1, 32, 64), got [1, 4, 64]
```
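For what it's worth, the message itself is easy to reproduce in plain PyTorch: `nn.LSTM` expects its hidden state as `[num_layers, batch, hidden_size]`, with the same batch size as the input. A minimal sketch, outside RLlib:

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=4, hidden_size=64, num_layers=1, batch_first=True)
x = torch.randn(32, 200, 4)   # input batched over 32 sequences
h0 = torch.zeros(1, 4, 64)    # hidden state batched over only 4 sequences
c0 = torch.zeros(1, 4, 64)

# Raises: RuntimeError: Expected hidden[0] size (1, 32, 64), got [1, 4, 64]
lstm(x, (h0, c0))
```

So it looks like the state batch RLlib hands my `forward()` no longer matches the observation batch, but I can't see why.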
And here is the full code to reproduce the error:
```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces


class dummy_env(gym.Env):
    def __init__(self, env_config):
        self.action_space = spaces.Tuple((
            spaces.Discrete(3),
            spaces.Box(low=0.8, high=1.2, shape=(1,), dtype=float),
            spaces.Box(low=0, high=0.1, shape=(1,), dtype=float),
            spaces.Box(low=0, high=0.1, shape=(1,), dtype=float),
        ))
        self.observation_space = spaces.Box(
            low=-10, high=10, shape=(3, 200, 4), dtype=np.float32)

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.done = False
        self.time = 0
        return self.get_observations(), {}

    def step(self, action):
        self.time += 1
        a1, a2, a3, a4 = action
        reward = 0
        observation = self.get_observations()
        if self.time == 10:
            self.done = True
        # gymnasium API: obs, reward, terminated, truncated, info
        return observation, reward, self.done, False, {}

    def get_observations(self):
        # Cast to float32 so the observation matches the space's dtype.
        return np.random.uniform(low=1, high=2, size=(3, 200, 4)).astype(np.float32)
```
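A quick standalone sanity check of the environment (hypothetical snippet, not part of the original script):

```python
# Smoke-test the env on its own before handing it to RLlib.
env = dummy_env({})
obs, info = env.reset()
assert obs.shape == (3, 200, 4)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
print(obs.shape, reward, terminated)
```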
```python
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, List

from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
from ray.rllib.utils.typing import TensorType

class LSTMModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        nn.Module.__init__(self)  # initialize nn.Module before TorchModelV2
        super(LSTMModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name)
        self.action_space_struct = get_base_struct_from_space(self.action_space)
        self.action_dim = 0
        self.hidden_size = 64
        self.concatenated_context = None  # set on the first forward() call
        # Two stacked LSTMs, shared across the three observation slices.
        # (Note: dropout has no effect when num_layers == 1.)
        self.lstm1 = nn.LSTM(input_size=obs_space.shape[2], hidden_size=64,
                             num_layers=1, batch_first=True, dropout=0.2)
        self.lstm2 = nn.LSTM(input_size=64, hidden_size=32,
                             num_layers=1, batch_first=True, dropout=0.2)
        # 3 slices * lstm2 hidden size (32) = 96 context features.
        self.fnn = SlimFC(
            in_size=int(self.hidden_size * 3 * 0.5),
            out_size=2,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self._logits_branch = SlimFC(
            in_size=int(self.hidden_size * 3 * 0.5 + 2),
            out_size=num_outputs,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )
        self._value_branch = SlimFC(
            in_size=int(self.hidden_size * 3 * 0.5 + 2),
            out_size=1,
            activation_fn=None,
            initializer=torch.nn.init.xavier_uniform_,
        )

    def forward(self, input_dict: Dict[str, TensorType],
                state: List[TensorType], seq_lens: TensorType):
        """Split the [batch, 3, 200, 4] observation into three [batch, 200, 4]
        slices, run each through the two stacked LSTMs, and concatenate the
        last-timestep outputs before the feed-forward heads."""
        obs = input_dict["obs"]
        obs1, obs2, obs3 = torch.split(obs, split_size_or_sections=1, dim=1)
        obs1 = obs1.squeeze(dim=1)
        obs2 = obs2.squeeze(dim=1)
        obs3 = obs3.squeeze(dim=1)
        # Each state entry comes in as [batch, hidden]; unsqueeze to the
        # [num_layers=1, batch, hidden] layout nn.LSTM expects.
        lstm_out1, (h1, c1) = self.lstm1(
            obs1, (torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)))
        lstm_out11, (h11, c11) = self.lstm2(
            lstm_out1, (torch.unsqueeze(state[2], 0), torch.unsqueeze(state[3], 0)))
        lstm_out2, (h2, c2) = self.lstm1(
            obs2, (torch.unsqueeze(state[4], 0), torch.unsqueeze(state[5], 0)))
        lstm_out22, (h22, c22) = self.lstm2(
            lstm_out2, (torch.unsqueeze(state[6], 0), torch.unsqueeze(state[7], 0)))
        lstm_out3, (h3, c3) = self.lstm1(
            obs3, (torch.unsqueeze(state[8], 0), torch.unsqueeze(state[9], 0)))
        lstm_out33, (h33, c33) = self.lstm2(
            lstm_out3, (torch.unsqueeze(state[10], 0), torch.unsqueeze(state[11], 0)))
        # Keep only the last timestep of each top-level LSTM output.
        lstm_out11 = lstm_out11[:, -1, :]
        lstm_out22 = lstm_out22[:, -1, :]
        lstm_out33 = lstm_out33[:, -1, :]
        # TODO: clean up the state handling.
        context = torch.cat([lstm_out11, lstm_out22, lstm_out33], dim=1)
        fnn_output = self.fnn(context)
        a, b = torch.split(fnn_output, 1, dim=1)
        self.concatenated_context = torch.cat([a, b, context], dim=1)
        output = self._logits_branch(self.concatenated_context)
        # squeeze() drops every singleton dim (including batch when batch == 1).
        return output, [h1.squeeze(), c1.squeeze(), h11.squeeze(), c11.squeeze(),
                        h2.squeeze(), c2.squeeze(), h22.squeeze(), c22.squeeze(),
                        h3.squeeze(), c3.squeeze(), h33.squeeze(), c33.squeeze()]

    def get_initial_state(self):
        # One (h, c) pair per LSTM per observation slice:
        # 3 slices * 2 LSTMs * 2 tensors = 12 initial state tensors.
        fnn = next(self.fnn._model.children())
        h = [
            fnn.weight.new(1, self.hidden_size).zero_().squeeze(0),
            fnn.weight.new(1, self.hidden_size).zero_().squeeze(0),
            fnn.weight.new(1, int(0.5 * self.hidden_size)).zero_().squeeze(0),
            fnn.weight.new(1, int(0.5 * self.hidden_size)).zero_().squeeze(0),
            fnn.weight.new(1, self.hidden_size).zero_().squeeze(0),
            fnn.weight.new(1, self.hidden_size).zero_().squeeze(0),
            fnn.weight.new(1, int(0.5 * self.hidden_size)).zero_().squeeze(0),
            fnn.weight.new(1, int(0.5 * self.hidden_size)).zero_().squeeze(0),
            fnn.weight.new(1, self.hidden_size).zero_().squeeze(0),
            fnn.weight.new(1, self.hidden_size).zero_().squeeze(0),
            fnn.weight.new(1, int(0.5 * self.hidden_size)).zero_().squeeze(0),
            fnn.weight.new(1, int(0.5 * self.hidden_size)).zero_().squeeze(0),
        ]
        return h

    def value_function(self) -> TensorType:
        assert self.concatenated_context is not None, "must call forward() first"
        return torch.reshape(self._value_branch(self.concatenated_context), [-1])
```
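To confirm the model itself runs, here is a hypothetical standalone check of one forward pass (outside RLlib; `num_outputs=9` is just a stand-in for whatever RLlib computes for the Tuple action space):

```python
# Hypothetical standalone check: one forward pass with batch size 1.
env = dummy_env({})
model = LSTMModel(env.observation_space, env.action_space,
                  num_outputs=9, model_config={}, name="test")
obs = torch.as_tensor(env.observation_space.sample()[None])  # [1, 3, 200, 4]
# get_initial_state() returns [hidden]-shaped tensors; add a batch dim of 1.
state = [s.unsqueeze(0) for s in model.get_initial_state()]
logits, new_state = model.forward({"obs": obs}, state, seq_lens=None)
print(logits.shape, [s.shape for s in new_state])
```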
```python
from ray.rllib.algorithms import ppo
from ray.tune.registry import register_env


def env_creator(env_config):
    return dummy_env(env_config)


register_env("dummy_env", env_creator)

env_config = {}
model_config = ppo.DEFAULT_CONFIG.copy()
model_config["env"] = "dummy_env"  # use the name registered above
model_config["env_config"] = env_config
model_config["batch_mode"] = "truncated_episodes"
model_config["num_gpus"] = 1
model_config["num_gpus_per_worker"] = 0
model_config["num_cpus_per_worker"] = 1  # "num_cpus" is not a PPO config key
model_config["num_workers"] = 1
model_config["num_sgd_iter"] = 6
model_config["framework"] = "torch"
model_config["train_batch_size"] = 1000
model_config["model"] = {"custom_model": "LSTMModel"}
import ray
from ray.rllib.models import ModelCatalog

ray.shutdown()
ray.init(ignore_reinit_error=True)
ModelCatalog.register_custom_model("LSTMModel", LSTMModel)

algo = ppo.PPO(config=model_config)
for i in range(10):
    results = algo.train()
    print("Iteration {}, episode_reward_mean = {}".format(
        i, results["episode_reward_mean"]))
```
> Expected hidden[0] size (1, 32, 64), got [1, 4, 64]
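For completeness, this is the pattern I decided not to use: RLlib's `RecurrentNetwork` base class, whose `forward()` flattens the observation, adds the time dimension, and calls `forward_rnn()` with `[B, T, obs_flat]` inputs. A minimal sketch along the lines of RLlib's torch `rnn_model` example (not my actual model; the [3, 200, 4] obs would arrive flattened):

```python
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override


class MinimalRNNModel(RecurrentNetwork, nn.Module):
    """Sketch only: a single LSTM over the flattened [3*200*4] observation."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.cell_size = 64
        self.lstm = nn.LSTM(int(np.prod(obs_space.shape)), self.cell_size,
                            batch_first=True)
        self._logits = nn.Linear(self.cell_size, num_outputs)
        self._value_branch = nn.Linear(self.cell_size, 1)
        self._features = None

    @override(RecurrentNetwork)
    def get_initial_state(self):
        # Shape [cell_size]; RLlib handles batching and time itself.
        w = self._logits.weight
        return [w.new_zeros(self.cell_size), w.new_zeros(self.cell_size)]

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        # inputs: [B, T, obs_flat]; state: two tensors of shape [B, cell_size].
        h, c = state[0].unsqueeze(0), state[1].unsqueeze(0)
        self._features, (h, c) = self.lstm(inputs, (h, c))
        return self._logits(self._features), [h.squeeze(0), c.squeeze(0)]

    @override(RecurrentNetwork)
    def value_function(self):
        return torch.reshape(self._value_branch(self._features), [-1])
```

Even if switching to `RecurrentNetwork` turns out to be the fix, I'd like to understand why the plain `TorchModelV2` version breaks.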