Using Dict observation space with custom RLModule

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hello! I'm trying to use RLlib's new API stack on a custom multi-agent env (although there is only one agent for now).

My observation is an image with multiple channels, and I managed to train it with PPO using the default parameters and NN architecture (so a CNN encoder and MLP heads for the policy and value function).

However, I would now like to add more info to my observations, in the form of a vector (for instance some one-hot encoding). What I want to do is simple, along the lines of what SB3 does for dict obs spaces (sorry, I cannot put a link, as new users are restricted to 2 links max per post).

So I updated my obs space to be a dict with two components: one for the image part, to be processed by a CNN, and one for the vector part, to be processed by an MLP. I did not find any existing example for this in RLlib's new API, so I tried to write a custom RLModule, based on the TinyAtariCNN example.
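
For reference, the new observation space looks something like this (shapes are illustrative, not the real ones from my env):

import numpy as np
import gymnasium as gym

observation_space = gym.spaces.Dict({
    # Multi-channel image part, meant for a CNN encoder.
    "image": gym.spaces.Box(low=0.0, high=1.0, shape=(64, 64, 4), dtype=np.float32),
    # Flat vector part (e.g. one-hot features), meant for an MLP encoder.
    "vector": gym.spaces.Box(low=0.0, high=1.0, shape=(8,), dtype=np.float32),
})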

I wanted to do something like this in the _forward() method of my RLModule:

obs_img = batch[Columns.OBS]["image"]
obs_vec = batch[Columns.OBS]["vector"]
# ... then process each part with torch.nn modules (a CNN for the image, an MLP for the vector)

However, an error occurs before anything in the RLModule is even called:

(MultiAgentEnvRunner pid=164203) 2024-12-04 15:57:31,156 ERROR actor_manager.py:187 -- Worker exception caught during apply(): all input arrays must have the same shape
(MultiAgentEnvRunner pid=164203) Traceback (most recent call last):
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/utils/actor_manager.py", line 183, in apply
(MultiAgentEnvRunner pid=164203) return func(self, *args, **kwargs)
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/execution/rollout_ops.py", line 110, in <lambda>
(MultiAgentEnvRunner pid=164203) else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics()))
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 467, in _resume_span
(MultiAgentEnvRunner pid=164203) return method(self, *_args, **_kwargs)
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 179, in sample
(MultiAgentEnvRunner pid=164203) samples = self._sample_timesteps(
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/util/tracing/tracing_helper.py", line 467, in _resume_span
(MultiAgentEnvRunner pid=164203) return method(self, *_args, **_kwargs)
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 377, in _sample_timesteps
(MultiAgentEnvRunner pid=164203) self._episode.finalize(drop_zero_len_single_agent_episodes=True)
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/env/multi_agent_episode.py", line 794, in finalize
(MultiAgentEnvRunner pid=164203) agent_eps.finalize()
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/env/single_agent_episode.py", line 576, in finalize
(MultiAgentEnvRunner pid=164203) self.observations.finalize()
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/env/utils/infinite_lookback_buffer.py", line 161, in finalize
(MultiAgentEnvRunner pid=164203) self.data = batch(self.data)
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/utils/spaces/space_utils.py", line 373, in batch
(MultiAgentEnvRunner pid=164203) ret = tree.map_structure(lambda *s: np_func(s, axis=0), *list_of_structs)
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/tree/__init__.py", line 435, in map_structure
(MultiAgentEnvRunner pid=164203) [func(*args) for args in zip(*map(flatten, structures))])
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/tree/__init__.py", line 435, in <listcomp>
(MultiAgentEnvRunner pid=164203) [func(*args) for args in zip(*map(flatten, structures))])
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/ray/rllib/utils/spaces/space_utils.py", line 373, in <lambda>
(MultiAgentEnvRunner pid=164203) ret = tree.map_structure(lambda *s: np_func(s, axis=0), *list_of_structs)
(MultiAgentEnvRunner pid=164203) File "/home/adrien/envs/rllib_env/lib/python3.10/site-packages/numpy/core/shape_base.py", line 449, in stack
(MultiAgentEnvRunner pid=164203) raise ValueError('all input arrays must have the same shape')
(MultiAgentEnvRunner pid=164203) ValueError: all input arrays must have the same shape

The two closest things I could find related to my problem are the ComplexInputNet from the old API stack (but that is the old API, which I'm not familiar with) and the FlattenObservations connector, but that seems quite complex and I don't want to flatten anything; I just want to pass the obs as-is to my RLModule and process it inside.
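
For reference, the connector route would look roughly like this (just a sketch; the env name is a placeholder and I have not checked the exact connector signature against my RLlib version):

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.connectors.env_to_module import FlattenObservations

# Sketch: insert the built-in connector in front of the RLModule so that Dict
# observations get flattened into a single 1D tensor before reaching the
# default encoder. Not what I want here, since it would destroy the image structure.
config = (
    PPOConfig()
    .environment("my_dict_env")  # placeholder env name
    .env_runners(env_to_module_connector=lambda env: FlattenObservations())
)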

Does anyone have an idea of how I can handle this? I'm sure this is a common problem and I might have overlooked some important information, but at this point I feel like I have read the docs and the relevant source code inside and out, and I'm quite lost :(

Thanks in advance!

This is the same final error message as described in "Training Action Masked PPO - ValueError: all input arrays must have the same shape", but in a different setting: multi-agent here vs. single-agent action masking there. I recommend opening a GitHub issue; the RLModule API is under active development at the moment.

@adrienJeg thanks for raising this. Any chance you can provide a reproducible example?

Hi, yes, I put this together quickly. It could be even simpler, because the issue is probably just that Dict observation spaces are not supported out of the box:

import ray
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.default_model_config import DefaultModelConfig

import gymnasium as gym

from ray.tune.registry import register_env


class DummyDictEnv(gym.Env):
    def __init__(self, duration=32, seed=None):
        self.duration = duration
        self.current_step = 0
        self.action_space = gym.spaces.Discrete(2, seed=seed)
        self.observation_space = gym.spaces.Dict({
            "a": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
            "b": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
        })

        super().__init__()

    def reset(self, *, seed=None, options=None):
        self.last_obs = self.observation_space.sample()
        self.current_step = 0
        return self.last_obs, {}
    
    def step(self, action):
        if action >= self.last_obs["a"][0]:
            reward = 1
        else:
            reward = -1
        self.last_obs = self.observation_space.sample()
        
        done = False
        if self.current_step >= self.duration:
            done = True

        self.current_step += 1
        
        return self.last_obs, reward, done, False, {}


def env_creator_dummy(config):
    env = DummyDictEnv()
    return env

register_env('dummy_env', lambda config: env_creator_dummy(config))

config = (
    PPOConfig()
    .environment(
        "dummy_env",
    )
    .training(
        lr=0.0003,
        num_epochs=6,
        vf_loss_coeff=0.01,
    )
    .rl_module(
        model_config=DefaultModelConfig(
            fcnet_hiddens=[32],
            fcnet_activation="linear",
            vf_share_layers=True,
        ),
    )
)

if __name__ == '__main__':

    tuner = tune.Tuner(
        trainable="PPO",
        param_space=config,
        run_config=air.RunConfig(
            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, checkpoint_frequency=20),
            stop={"num_env_steps_sampled_lifetime": 10000},
        ),
    )

    context = ray.init()
    results = tuner.fit()

I speculate that it is due to the way RLlib batches observations: you need a way to "concatenate" all tensors together, this is not supported directly, and you need to add a connector for it. In the meantime, I will try something quick and dirty, like putting all my scalar obs inside a dedicated channel of my image obs and then processing it inside my RLModule (I think this should work?).
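
Roughly what I have in mind for that quick-and-dirty workaround (an untested sketch; pack_obs is a made-up helper and the shapes are illustrative):

import numpy as np

def pack_obs(image: np.ndarray, vector: np.ndarray) -> np.ndarray:
    # Append one extra channel to the image and write the scalar obs into its
    # first cells. image: (H, W, C), vector: (N,) with N <= H * W.
    h, w, _ = image.shape
    extra = np.zeros((h, w, 1), dtype=image.dtype)
    extra.flat[: vector.size] = vector
    return np.concatenate([image, extra], axis=-1)  # -> (H, W, C + 1)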

Have a good day!

@adrienJeg thanks for the repro. We can reproduce it and are working on a solution now. We will keep you updated on it.

OK, so I think I found a decent solution (probably not optimal, as I am not familiar with the inner workings of RLlib).

The core of the issue is that when the catalog is initialized, it calls the method _determine_components_hook(), which then calls _get_encoder_config(), and a Dict observation space is not a case that is handled there, so it raises an exception.

So the cleanest thing to do would be to override _get_encoder_config() to return a custom EncoderConfig that does the job. That is my next step; however, I haven't had time to do it yet.

So what I did was basically short-circuit the Catalog by overriding its __init__ method and keeping only the bare minimum, and then define a custom RLModule that is compliant with the Dict observation space defined in the env (here it just passes each part of the obs through two distinct MLPs, then concatenates their embeddings).

Working example
import functools

import ray
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.tune.registry import register_env

import gymnasium as gym
import numpy as np

from typing import Any, Dict, Optional
from ray.rllib.utils.typing import TensorType

from ray.rllib.core.columns import Columns
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


torch, nn = try_import_torch()


class DummyDictEnv(gym.Env):
    """
    At each step, the observation is a dict of two obs that are between 0 and 1.
    The agent has to choose for each obs a discrete action (0 or 1), and is rewarded
    if it chooses the closest number, on both "axis".
    """
    def __init__(self, duration=32, seed=None):
        self.duration = duration
        self.current_step = 0

        # Basically MultiBinary(2) but unsupported by RLlib out of the box
        self.action_space = gym.spaces.MultiDiscrete([2, 2], seed=seed)
        self.observation_space = gym.spaces.Dict({
            "foo": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
            "bar": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
        })

        super().__init__()

    def reset(self, *, seed=None, options=None):
        self.last_obs = self.observation_space.sample()
        self.current_step = 0
        return self.last_obs, {}
    
    def step(self, action):
        reward = 0

        good_foo_action = round(self.last_obs["foo"][0])
        if action[0] == good_foo_action:
            reward += 1
        else:
            reward -= 1

        good_bar_action = round(self.last_obs["bar"][0])
        if action[1] == good_bar_action:
            reward += 1
        else:
            reward -= 1

        self.last_obs = self.observation_space.sample()
        
        done = False
        if self.current_step >= self.duration:
            done = True

        self.current_step += 1
        
        return self.last_obs, reward, done, False, {}


class TestDictTorchPPORLModule(TorchRLModule, ValueFunctionAPI):
    @override(TorchRLModule)
    def setup(self):
        assert isinstance(self.action_space, gym.spaces.MultiDiscrete)
        assert isinstance(self.observation_space, gym.spaces.Dict)

        self.concat_embedding_size = 32
        self.pi_out_features = int(np.sum(self.action_space.nvec))  # MultiCategorical needs sum(nvec) logits (here sum == prod, so either works)

        self.foo_encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_space["foo"].shape[0], out_features=16, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=self.concat_embedding_size, bias=True),
        )

        self.bar_encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_space["bar"].shape[0], out_features=16, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=self.concat_embedding_size, bias=True),
        )

        self.my_pi = nn.Sequential(
            nn.Linear(in_features=self.concat_embedding_size * 2, out_features=64, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=self.pi_out_features, bias=True),
        )

        self.my_vf = nn.Sequential(
            nn.Linear(in_features=self.concat_embedding_size * 2, out_features=64, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=1, bias=True),
        )

    @override(TorchRLModule)
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Default forward pass (used for inference and exploration)."""
        # Compute the basic 1D feature tensor (inputs to policy- and value-heads).
        _, logits = self._compute_embeddings_and_logits(batch)
        # Return the logits as ACTION_DIST_INPUTS (multi-categorical distribution over the MultiDiscrete action space).
        return {
            Columns.ACTION_DIST_INPUTS: logits,
        }

    @override(TorchRLModule)
    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Compute the basic 1D feature tensor (inputs to policy- and value-heads).
        embeddings, logits = self._compute_embeddings_and_logits(batch)
        # Return features and logits as ACTION_DIST_INPUTS.
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: embeddings,
        }
      
    # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used
    # by value-based methods like PPO or IMPALA.
    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ) -> TensorType:
        # Features not provided -> We need to compute them first.
        if embeddings is None:
            obs = batch[Columns.OBS]
            foo_obs = obs["foo"]
            bar_obs = obs["bar"]
            foo_embeddings = self.foo_encoder(foo_obs)
            bar_embeddings = self.bar_encoder(bar_obs)
            embeddings = torch.cat((foo_embeddings, bar_embeddings), dim=-1)  # Concatenate both outputs along feature dim
        
        return self.my_vf(embeddings).squeeze(-1)  # Squeeze out last dimension (single node value head).
    
    def _compute_embeddings_and_logits(self, batch):
        obs = batch[Columns.OBS]
        foo_obs = obs["foo"]
        bar_obs = obs["bar"]
        foo_embeddings = self.foo_encoder(foo_obs)
        bar_embeddings = self.bar_encoder(bar_obs)
        embeddings = torch.cat((foo_embeddings, bar_embeddings), dim=-1)  # Concatenate both outputs along feature dim
        logits = self.my_pi(embeddings)

        return embeddings, logits
    

class TestDictCatalog(PPOCatalog):
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        # Short-circuit the normal Catalog __init__ entirely to avoid the call to _get_encoder_config(),
        # which does not handle Dict spaces yet.
        self.observation_space = observation_space
        self.action_space = action_space

        self._action_dist_class_fn = functools.partial(
            self._get_dist_cls_from_action_space, action_space=self.action_space
        )

    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        **kwargs,
    ):
        if isinstance(observation_space, gym.spaces.Dict):
            # TODO here
            raise NotImplementedError
        else:
            return super()._get_encoder_config(observation_space, **kwargs)


def env_creator_dummy(config):
    env = DummyDictEnv()
    return env

register_env('dummy_env', lambda config: env_creator_dummy(config))

config = (
    PPOConfig()
    .environment(
        "dummy_env",
    )
    .env_runners(
        num_env_runners=6,
        num_envs_per_env_runner=1,
    )
    .learners(
        num_learners=1,
        num_gpus_per_learner=1,
    )
    .training(
        lr=0.0003,
        num_epochs=6,
        vf_loss_coeff=0.01,
    )
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=TestDictTorchPPORLModule,
            catalog_class=TestDictCatalog,
        ),
    )
)

if __name__ == '__main__':

    tuner = tune.Tuner(
        trainable="PPO",
        param_space=config,
        run_config=air.RunConfig(
            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, checkpoint_frequency=20),
            stop={"num_env_steps_sampled_lifetime": 200000},
        ),
    )

    context = ray.init()
    results = tuner.fit()

We can see that it trains pretty well on my dummy env, so I think it works!

I will try to keep this thread updated if I manage to create a real EncoderConfig, to avoid hardcoding the NN architecture.

Best regards