How to train custom models with `SampleBatch.INFOS`

I’m trying to implement a centralized critic (e.g., the MADDPG or MAPPO algorithm) without using a custom policy with postprocess_trajectory, so I can switch my base policy algorithm via tune.run("algo") without writing new policy classes.

I want to use my custom model and extract the environment’s global state from SampleBatch.INFOS. For example:

from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNNModel
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import sequence_mask

torch, nn = try_import_torch()

class CustomModel(TorchRNNModel, nn.Module):
    def __init__(self,
                 obs_space,
                 action_space,
                 num_outputs,
                 model_config,
                 name,
                 **kwargs):
        ...
        self.view_requirements[SampleBatch.INFOS] = ViewRequirement()
        self._global_state = None

    @override(TorchRNNModel)
    def forward(self, input_dict, state, seq_lens):
        infos = input_dict[SampleBatch.INFOS]
        # `extract_from_infos` is a user-defined helper that builds the global
        # state tensor from the per-step info dicts.
        self._global_state = extract_from_infos(infos)
        return super().forward(input_dict, state, seq_lens)

    def value_function(self):
        # Centralized Critic
        assert self._global_state is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._global_state), [-1])
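
For reference, the model is plugged in roughly like this (a sketch; "custom_rnn" and "my_env" are placeholder names), so the base algorithm can still be swapped through tune.run:

from ray import tune
from ray.rllib.models import ModelCatalog

# Register the custom model under an arbitrary name.
ModelCatalog.register_custom_model("custom_rnn", CustomModel)

tune.run(
    "PPO",  # any registered algorithm name can go here
    config={
        "env": "my_env",  # placeholder env name
        "framework": "torch",
        "model": {
            "custom_model": "custom_rnn",
        },
    },
)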

I found that, during sampling, evaluation, and postprocessing, the value for the key SampleBatch.INFOS is a numpy.ndarray filled with zeros:

self.training -> False
input_dict[SampleBatch.INFOS] -> np.array([0.0, ..., 0.0], dtype=np.float32)

In training:

self.training -> True
input_dict[SampleBatch.INFOS] -> np.array([{...}, ..., {...}], dtype=np.dtype("O"))
# the `INFOS` entries correspond to `NEXT_OBS`, not `OBS`
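
One workaround I am considering, based on the dtype difference above (just a sketch; extract_from_infos is the same user-defined helper as before, and global_state_dim is a hypothetical attribute), is to guard forward() so the global state is only extracted when INFOS actually contains dicts:

    @override(TorchRNNModel)
    def forward(self, input_dict, state, seq_lens):
        infos = input_dict[SampleBatch.INFOS]
        if infos.dtype == object:
            # Training pass: INFOS holds the actual per-step info dicts.
            self._global_state = extract_from_infos(infos)
        else:
            # Sampling/evaluation pass: INFOS is only a zero-filled float array,
            # so fall back to a dummy global state of the right shape.
            self._global_state = torch.zeros(len(infos), self.global_state_dim)
        return super().forward(input_dict, state, seq_lens)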

I wonder what the best practice is for training and sampling with SampleBatch.INFOS? Thanks very much!

BTW, there is also a bug detailed in ray-project/ray#21991: [RLlib][Bug] Wrong number of samples in batch[SampleBatch.INFOS] in RNN cases.

@avnishn any thoughts here?

Hi @XuehaiPan

I am not sure if it would work for you, but the way I would approach this is by using a Dict observation space and putting the global state as one of the keys in the observation dictionary.
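
For example, a minimal sketch (shapes and key names are arbitrary placeholders):

import numpy as np
from gym import spaces

observation_space = spaces.Dict({
    # The agent's own observation, used by the policy network.
    "obs": spaces.Box(-np.inf, np.inf, shape=(10,), dtype=np.float32),
    # The environment's global state, used only by the centralized critic.
    "global_state": spaces.Box(-np.inf, np.inf, shape=(20,), dtype=np.float32),
})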

The other thing you could try is explicitly adding info to the ViewRequirements for both compute actions and training.
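
Something along these lines in the model (a sketch; please double-check the exact ViewRequirement keyword arguments for your Ray version):

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.view_requirement import ViewRequirement

# Inside the model's __init__():
self.view_requirements[SampleBatch.INFOS] = ViewRequirement(
    data_col=SampleBatch.INFOS,
    used_for_compute_actions=True,  # expose INFOS during sampling/inference
    used_for_training=True,         # and in the train batch
)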

Thanks for the reply!

> I am not sure if it would work for you, but the way I would approach this is by using a Dict observation space and putting the global state as one of the keys in the observation dictionary.

I think this would be the solution, since we don’t have info at the initial step (env.reset()).

> The other thing you could try is explicitly adding info to the ViewRequirements for both compute actions and training.

I tried adding ViewRequirement(SampleBatch.INFOS), but I got a numpy array of zeros in the sampler’s sample batch (I expected an array of dicts). This makes sense because info is returned together with next_obs: the info data hasn’t been added to the buffer yet when policy.compute_actions (or policy.model.forward) is called, so the batch only contains zeros.
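
In other words, the rollout loop looks roughly like this (using an arbitrary env and random actions just to illustrate the ordering), so the info for the step being acted on simply does not exist yet:

import gym

env = gym.make("CartPole-v1")  # arbitrary env, only to show the step ordering
obs = env.reset()              # old gym API: reset() returns obs only, no info
done = False
while not done:
    # Only `obs` (and infos from *previous* steps) exist at this point, which is
    # why SampleBatch.INFOS cannot be filled for the step the action is computed for.
    action = env.action_space.sample()  # stand-in for policy.compute_actions(...)
    obs, reward, done, info = env.step(action)  # info arrives together with next_obs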

I also tried ViewRequirement(SampleBatch.INFOS, shift=-1) and ViewRequirement(SampleBatch.INFOS, shift=+1). I got IndexErrors in remote workers.

I feel like INFO is the right way to go about this.
If you put the global state in OBS under a Dict input space, that data will get fed into the NN, or you will have to clean it up yourself, which is extra work.
I vaguely remember that we may need to configure a whitelist somewhere so RLlib knows to keep your special state under INFO.

Any chance you can provide a complete small script to demonstrate your use case, so we can debug a bit on our end?

Thanks.

Hi @gjoliver,

The issue with this approach, as @XuehaiPan points out, is that the gym API for reset() does not return info, so it will always be missing in the first step. @XuehaiPan mentioned using a custom model in the first post, and if you are writing a custom model then incorporating a dict space is not really any extra work.
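
For instance, with a Dict space like the one sketched above, the split is only a couple of lines in the custom model's forward() (rough sketch; self.policy_net is a placeholder for whatever network you use):

    def forward(self, input_dict, state, seq_lens):
        # For Dict observation spaces, RLlib restores the original structure
        # under input_dict["obs"] (the flattened version is in input_dict["obs_flat"]).
        own_obs = input_dict["obs"]["obs"]
        self._global_state = input_dict["obs"]["global_state"]  # kept for the critic
        logits, state = self.policy_net(own_obs, state)  # placeholder policy network
        return logits, state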

I see, got it.
Actually, on second thought, maybe we should question whether keeping a global state on the model class is the right thing to do, e.g., when one model instance is serving multiple rollout envs.