Hi,
I’m trying to use RLlib’s IMPALA with a custom model to reimplement variations of the network from this paper: [1910.13406] Generalization of Reinforcement Learners with Working and Episodic Memory
So I wrote an abstract MRA class:
from typing import Type
import gymnasium as gym
import torch
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from torch import nn
from torch.nn import functional as F
class AbstractMRA(RecurrentNetwork, nn.Module):
    """Recurrent RLlib torch model assembled from three pluggable sub-networks.

    The model chains a feature extractor, a memory module and a
    working-memory module. Each one is supplied as a class plus a dict of
    constructor keyword arguments and is instantiated in ``__init__``.

    Interface the sub-networks must provide, as used by this class:

    * ``feature_net(obs)`` returns the per-step features ``x_t``.
    * ``mem_net(x_t, h_t_prev)`` returns the memory read-out ``m_t``, and
      ``mem_net.write(x_t, h_t)`` is called once per forward pass.
    * ``working_mem_net(inputs, state)`` returns the 4-tuple
      ``(h_t, actions, values, new_state)`` and also exposes
      ``output_dim()`` and ``initial_states()``.

    Args:
        obs_space: Observation space, forwarded to ``RecurrentNetwork``.
        action_space: Action space, forwarded to ``RecurrentNetwork``.
        num_outputs: Size of the policy output, forwarded to
            ``RecurrentNetwork``.
        model_config: Model config dict, forwarded to ``RecurrentNetwork``.
        name: Model name, forwarded to ``RecurrentNetwork``.
        feature_net: Class of the feature extractor.
        feature_net_config: Keyword arguments for ``feature_net``.
        mem_net: Class of the memory module.
        mem_net_config: Keyword arguments for ``mem_net``.
        working_mem_net: Class of the working-memory module.
        working_mem_net_config: Keyword arguments for ``working_mem_net``.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: dict,
        name: str,
        feature_net: Type[nn.Module],
        feature_net_config: dict,
        mem_net: Type[nn.Module],
        mem_net_config: dict,
        working_mem_net: Type[nn.Module],
        working_mem_net_config: dict,
    ) -> None:
        RecurrentNetwork.__init__(
            self,
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        nn.Module.__init__(self)

        self.feature_net = feature_net(**feature_net_config)
        self.mem_net = mem_net(**mem_net_config)
        self.working_mem_net = working_mem_net(**working_mem_net_config)

        # Filled in by forward_rnn(); value_function() refuses to run before.
        self._values = None

        # Placeholder "previous hidden state" for the very first call. The
        # two leading singleton axes are meant to broadcast against
        # (batch, time, working_mem_features).
        self.h_t_prev = torch.zeros((1, 1, self.working_mem_net.output_dim()))

    def get_initial_state(self):
        """Return the initial recurrent state supplied by the working memory.

        The value is whatever ``working_mem_net.initial_states()`` returns
        (e.g. zero-initialised ``h`` and ``c`` for an LSTM).
        """
        return self.working_mem_net.initial_states()

    def value_function(self):
        """Return the value estimates cached by the last ``forward_rnn`` call.

        Raises:
            RuntimeError: If no forward pass has been run yet.
        """
        if self._values is None:
            raise RuntimeError("must call forward() first")
        # NOTE(review): the tensor is returned as produced by working_mem_net,
        # expected to be (batch, time, 1). RLlib's reference recurrent torch
        # models return the value output flattened to 1-D -- confirm which
        # shape the algorithm expects.
        return self._values

    def forward_rnn(self, inputs, state, seq_lens):
        """Run features -> memory read -> working memory for one batch.

        Args:
            inputs: Flattened observations; they are unpacked with
                ``restore_original_dimensions`` and the ``"RGB_INTERLEAVED"``
                entry is used as the image input.
            state: Recurrent state, passed to ``working_mem_net`` unchanged.
            seq_lens: Sequence lengths; not used by this implementation.

        Returns:
            ``(actions, new_state)`` as produced by ``working_mem_net``.
        """
        original_obs = restore_original_dimensions(
            inputs, self.obs_space, tensorlib="torch"
        )
        # Expected (batch, time, width, height, n_channels).
        picture_obs = original_obs["RGB_INTERLEAVED"]

        # Expected (batch, time, n_features).
        x_t = self.feature_net(picture_obs)

        # Expected (batch, time, memory_features).
        # NOTE(review): h_t_prev comes from the previous call, so its batch and
        # time dimensions may differ from the current inputs -- confirm the
        # memory module can cope with that.
        m_t = self.mem_net(x_t, self.h_t_prev)

        (
            h_t,  # expected (batch, time, working_mem_features)
            actions,  # expected (batch, time, num_outputs)
            values,  # expected (batch, time, 1)
            new_state,  # same size as state
        ) = self.working_mem_net(torch.concat((x_t, m_t), dim=2), state)

        # Keep only the values, not the autograd graph: the tensor outlives
        # this forward pass, and backpropagating into a previous call's graph
        # (possibly already freed) is never intended.
        self.h_t_prev = h_t.detach()
        self._values = values
        self.mem_net.write(x_t, h_t)
        return actions, new_state
To test it, I wrote dummy components:
class DummyFeatNet(nn.Module):
    """Placeholder feature extractor: flatten each frame, then Linear + ReLU.

    Args:
        in_features: Number of values per frame once every axis after
            (batch, time) is flattened. Defaults to ``72 * 96 * 3``.
        out_features: Size of the feature vector produced per frame.
            Defaults to ``512``.
    """

    def __init__(
        self, in_features: int = 72 * 96 * 3, out_features: int = 512
    ) -> None:
        super().__init__()
        self.lin = nn.Linear(in_features=in_features, out_features=out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(batch, time, ...)`` to ``(batch, time, out_features)``."""
        # Axes 0 and 1 are kept; everything from axis 2 onwards is merged so
        # the linear layer sees one flat vector per time step.
        return F.relu(self.lin(x.flatten(start_dim=2)))
class DummyMem(nn.Module):
    """Placeholder memory module: reads back zeros and ignores writes."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, features: torch.Tensor, prev_hidden: torch.Tensor
    ) -> torch.Tensor:
        """Return an all-zero read-out with a single memory feature.

        Args:
            features: Tensor whose first two dimensions give the batch and
                time sizes of the result.
            prev_hidden: Previous hidden state; ignored by this dummy.

        Returns:
            Zeros of shape ``(features.shape[0], features.shape[1], 1)`` with
            the same dtype and device as ``features``.
        """
        # new_zeros inherits dtype/device from `features`, so the read-out can
        # be concatenated with the features wherever they live. A plain
        # torch.zeros() would always allocate on the default device.
        return features.new_zeros((features.shape[0], features.shape[1], 1))

    def write(self, *args) -> None:
        """Accept any arguments and store nothing."""
        pass
class DummyWorkingMem(nn.Module):
    """Placeholder working memory: two linear heads and a pass-through state.

    Args:
        in_features: Width of the input vector per time step. Defaults to
            ``513``.
        num_outputs: Width of the action head. Defaults to ``8``.
    """

    def __init__(self, in_features: int = 513, num_outputs: int = 8) -> None:
        super().__init__()
        self.lin1 = nn.Linear(in_features=in_features, out_features=num_outputs)
        self.lin2 = nn.Linear(in_features=in_features, out_features=1)

    def forward(self, x: torch.Tensor, state):
        """Return ``(hidden, actions, values, state)``.

        The input itself is returned as the hidden state and ``state`` is
        handed back untouched.
        """
        actions = self.lin1(x)
        values = self.lin2(x)
        return x, actions, values, state

    def output_dim(self):
        """Return the width of the hidden state produced by ``forward``."""
        # forward() returns its input as the hidden state, so the hidden width
        # is the input width (a fixed 512 disagreed with the 513-wide input).
        return self.lin1.in_features

    def initial_states(self):
        """Return the initial recurrent state as a list with one zero tensor."""
        # A list of state tensors, not a bare tensor: a caller iterating over
        # the result then sees one state of shape (in_features,) rather than
        # in_features scalar states.
        return [torch.zeros(self.lin1.in_features)]
I then tried running it with:
# Build the IMPALA algorithm around the custom model and its dummy parts.
algo = impala.Impala(
    env="my_dmm_env",
    config={
        # Settings handed to the environment.
        "env_config": {
            "seed": 123,
            "level_name": "spot_diff_train",
        },
        "framework": "torch",
        "model": {
            "custom_model": AbstractMRA,
            # Keys mirror the extra keyword parameters of AbstractMRA.__init__;
            # every sub-network is built with an empty kwargs dict.
            "custom_model_config": {
                "feature_net": DummyFeatNet,
                "feature_net_config": {},
                "mem_net": DummyMem,
                "mem_net_config": {},
                "working_mem_net": DummyWorkingMem,
                "working_mem_net_config": {},
            },
        },
    },
)
But I ran into problems. As is, I get this error:
ray/rllib/evaluation/postprocessing.py, line 313, in compute_bootstrap_value
-> sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate(
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 1 dimension(s)
I read the code a bit and figured the error was related to the shape of the values returned by the model, which is also what this post discusses. So I tried adding a squeeze in value_function:
def value_function(self):
    """Return the cached value estimates with the trailing axis squeezed.

    Raises:
        RuntimeError: If no forward pass has stored values yet.
    """
    cached = self._values
    if cached is not None:
        # Drops only the last axis, so a (batch_size, time, 1) tensor comes
        # back as (batch_size, time).
        # NOTE(review): that is still 2-D; RLlib's reference recurrent torch
        # models flatten the value output to 1-D -- confirm the expected rank.
        return cached.squeeze(-1)
    raise RuntimeError("must call forward() first")
which then gave me this error:
ray/rllib/algorithms/impala/vtrace_torch.py, line 310, in from_importance_weights
-> assert rho_rank == len(values.size())
AssertionError
and I can’t figure out what’s wrong.