Custom Impala model

Hi,
I’m trying to use Impala with a custom model to reimplement variations of the network from this paper: [1910.13406] Generalization of Reinforcement Learners with Working and Episodic Memory

So I made an abstract MRA class:

from typing import Type

import gymnasium as gym
import torch
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from torch import nn
from torch.nn import functional as F


class AbstractMRA(RecurrentNetwork, nn.Module):
    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: dict,
        name: str,
        feature_net: Type[nn.Module],
        feature_net_config: dict,
        mem_net: Type[nn.Module],
        mem_net_config: dict,
        working_mem_net: Type[nn.Module],
        working_mem_net_config: dict,
    ) -> None:
        RecurrentNetwork.__init__(
            self,
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )
        nn.Module.__init__(self)
        self.feature_net = feature_net(**feature_net_config)
        self.mem_net = mem_net(**mem_net_config)
        self.working_mem_net = working_mem_net(**working_mem_net_config)
        self._values = None
        self.h_t_prev = torch.zeros(
            (1, 1, self.working_mem_net.output_dim())
        )  # can broadcast with size (batch, time, working_mem_features)

    def get_initial_state(self):
        return (
            self.working_mem_net.initial_states()
        )  # e.g. returns h and c for LSTM, set at 0

    def value_function(self):
        if self._values is None:
            raise RuntimeError("must call forward() first")
        return self._values  # Size (batch_size, time, 1)

    def forward_rnn(self, inputs, state, seq_lens):
        original_obs = restore_original_dimensions(
            inputs, self.obs_space, tensorlib="torch"
        )
        picture_obs = original_obs[
            "RGB_INTERLEAVED"
        ]  # Size (batch_size, time, width, height, n_channels)
        x_t = self.feature_net(picture_obs)  # Size (batch_size, time, n_features)
        m_t = self.mem_net(
            x_t, self.h_t_prev
        )  # Size (batch_size, time, memory_features)
        (
            h_t,  # size (batch_size, time, working_mem_features)
            actions,  # size (batch_size, time, num_outputs)
            values,  # size (batch_size, time, 1)
            new_state,  # same size as state
        ) = self.working_mem_net(torch.concat((x_t, m_t), dim=2), state)
        self.h_t_prev = h_t
        self._values = values
        self.mem_net.write(x_t, h_t)
        return actions, new_state

and in order to test it, I wrote some dummy components:

class DummyFeatNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(in_features=72 * 96 * 3, out_features=512)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self.lin(x.flatten(start_dim=2))
        output = F.relu(output)
        return output


class DummyMem(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, features: torch.Tensor, prev_hidden: torch.Tensor
    ) -> torch.Tensor:
        return torch.zeros((features.shape[0], features.shape[1], 1))

    def write(self, *args) -> None:
        pass


class DummyWorkingMem(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin1 = nn.Linear(in_features=513, out_features=8)
        self.lin2 = nn.Linear(in_features=513, out_features=1)

    def forward(self, x: torch.Tensor, state):
        actions = self.lin1(x)
        values = self.lin2(x)
        return x, actions, values, state

    def output_dim(self):
        return 512

    def initial_states(self):
        return torch.zeros(513)
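
To sanity-check the shapes outside of RLlib, I ran a quick throwaway script like this (batch and time sizes made up):

import torch

feat = DummyFeatNet()
mem = DummyMem()
wm = DummyWorkingMem()

obs = torch.rand(4, 5, 72, 96, 3)  # (batch, time, width, height, channels)
x_t = feat(obs)  # (4, 5, 512)
m_t = mem(x_t, torch.zeros(1, 1, 512))  # (4, 5, 1)
h_t, actions, values, state = wm(torch.concat((x_t, m_t), dim=2), wm.initial_states())
print(h_t.shape, actions.shape, values.shape)  # (4, 5, 513), (4, 5, 8), (4, 5, 1)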

I then tried running it with:

from ray.rllib.algorithms import impala

algo = impala.Impala(env="my_dmm_env", config={
    "env_config": {
        "seed": 123,
        "level_name": "spot_diff_train",
    },
    "framework": "torch",
    "model": {
        "custom_model": AbstractMRA,
        "custom_model_config": {
            "feature_net": DummyFeatNet,
            "feature_net_config": {},
            "mem_net": DummyMem,
            "mem_net_config": {},
            "working_mem_net": DummyWorkingMem,
            "working_mem_net_config": {},
        },
    },
})
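
(For reference, "my_dmm_env" is registered beforehand, roughly like below; MyDmmEnv is just a placeholder for my actual environment wrapper, which I can share if needed.)

from ray.tune.registry import register_env

# MyDmmEnv stands in for my actual environment class (not shown here)
register_env("my_dmm_env", lambda env_config: MyDmmEnv(env_config))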

But I encounter problems… As is, I get this error:

ray/rllib/evaluation/postprocessing.py, line 313, in compute_bootstrap_value
-> sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate(
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 1 dimension(s)
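
As far as I can tell, the error itself is just plain numpy behaviour when concatenating arrays of different rank; a made-up minimal repro (shapes are illustrative only):

import numpy as np

# index 0 would be my (batch, time, 1) value predictions, index 1 the 1-D
# bootstrap value being appended; shapes here are invented
a = np.zeros((32, 20, 1))
b = np.zeros((1,))
np.concatenate([a, b])
# ValueError: all the input arrays must have same number of dimensions, ...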

I read the code a bit and figured it was linked to the shape of the values returned by the model, which also seems related to this post. So I tried adding a squeeze in value_function:

    def value_function(self):
        if self._values is None:
            raise RuntimeError("must call forward() first")
        return self._values.squeeze(-1)  # Size (batch_size, time)

which then got me this error:

ray/rllib/algorithms/impala/vtrace_torch.py, line 310, in from_importance_weights
-> assert rho_rank == len(values.size())
AssertionError

and I can’t figure out what’s wrong.

In the end, I looked at the provided torch example on writing RNN model classes (the one with a GRU), and if I follow it correctly, with size hints added, we have:

def forward_rnn(self, input_dict, state, seq_lens):
    x = nn.functional.relu(self.fc1(input_dict["obs_flat"].float()))  # Size (batch, time, flat_feats)
    h_in = state[0].reshape(-1, self.rnn_hidden_dim)
    h = self.rnn(x, h_in)  # Size (batch, time, rnn_hidden_dim)
    q = self.fc2(h)
    self._cur_value = self.value_branch(h).squeeze(1)  # Size (batch, time, 1)
    return q, [h]

If I follow correctly, squeezing dim 1 does nothing here and dim 2 should be squeezed instead. Also, I don’t get an input_dict (of type SampleBatch) as input as they do, but only a single tensor of size (batch, time, flattened_dims), so I find this part of the documentation quite unhelpful and misleading.
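
Based on that, my current best guess is that value_function should return a flat 1-D tensor over batch * time instead of keeping the time and singleton dimensions, i.e. something like this (untested guess):

    def value_function(self):
        if self._values is None:
            raise RuntimeError("must call forward() first")
        # untested guess: flatten (batch, time, 1) down to a 1-D (batch * time,) tensor
        return torch.reshape(self._values, [-1])

But before I keep guessing at shapes, I’d like to understand what contract value_function and forward_rnn are actually expected to follow.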

Does anyone know what’s wrong with my code? (I can provide the full stack trace and the environment code if needed.)

Extra info: the observation space is the following:
Dict('AvatarPosition': Box(-inf, inf, (3,), float32), 'RGB_INTERLEAVED': Box(0, 255, (72, 96, 3), uint8), 'Score': Box(-inf, inf, (), float64))
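
and, if I compute it right, the single flattened tensor that forward_rnn receives should therefore have 3 + 72 * 96 * 3 + 1 = 20740 features per timestep (quick check with gymnasium, assuming I rebuilt the space correctly):

import gymnasium as gym
import numpy as np

obs_space = gym.spaces.Dict({
    "AvatarPosition": gym.spaces.Box(-np.inf, np.inf, (3,), np.float32),
    "RGB_INTERLEAVED": gym.spaces.Box(0, 255, (72, 96, 3), np.uint8),
    "Score": gym.spaces.Box(-np.inf, np.inf, (), np.float64),
})
print(gym.spaces.flatdim(obs_space))  # 3 + 20736 + 1 = 20740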