Using Dict observation space with custom RLModule

Ok so I think I found a decent solution (probably not optimal, as I am not familiar with the inner workings of RLlib).

The core of the issue is that when the Catalog is initialized, it calls _determine_components_hook(), which in turn calls _get_encoder_config(). Dict observation spaces are not a handled case there, so it raises an exception.

So the cleanest fix would be to override _get_encoder_config() so that it returns a custom EncoderConfig that does the job. This is my next step, but I haven't had the time to do it yet.

So what I did instead was basically short-circuit the Catalog by overriding its __init__ and keeping only the bare minimum, and then define a custom RLModule that handles the Dict observation space defined in the env (here it simply passes each part of the obs through its own MLP, then concatenates the two embeddings).

Working example
import functools

import ray
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.tune.registry import register_env

import gymnasium as gym
import numpy as np

from typing import Any, Dict, Optional
from ray.rllib.utils.typing import TensorType

from ray.rllib.core.columns import Columns
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch

from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


torch, nn = try_import_torch()


class DummyDictEnv(gym.Env):
    """
    At each step, the observation is a dict of two values between 0 and 1.
    For each value, the agent has to choose a discrete action (0 or 1), and is
    rewarded for picking the closer integer, on both "axes".
    """
    def __init__(self, duration=32, seed=None):
        self.duration = duration
        self.current_step = 0

        # Basically MultiBinary(2) but unsupported by RLlib out of the box
        self.action_space = gym.spaces.MultiDiscrete([2, 2], seed=seed)
        self.observation_space = gym.spaces.Dict({
            "foo": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
            "bar": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
        })

        super().__init__()

    def reset(self, *, seed=None, options=None):
        self.last_obs = self.observation_space.sample()
        self.current_step = 0
        return self.last_obs, {}
    
    def step(self, action):
        reward = 0

        good_foo_action = round(self.last_obs["foo"][0])
        if action[0] == good_foo_action:
            reward += 1
        else:
            reward -= 1

        good_bar_action = round(self.last_obs["bar"][0])
        if action[1] == good_bar_action:
            reward += 1
        else:
            reward -= 1

        self.last_obs = self.observation_space.sample()
        
        done = self.current_step >= self.duration
        self.current_step += 1

        return self.last_obs, reward, done, False, {}


class TestDictTorchPPORLModule(TorchRLModule, ValueFunctionAPI):
    @override(TorchRLModule)
    def setup(self):
        assert isinstance(self.action_space, gym.spaces.MultiDiscrete)
        assert isinstance(self.observation_space, gym.spaces.Dict)

        self.concat_embedding_size = 32
        # The MultiCategorical distribution for a MultiDiscrete space expects one logit
        # per discrete choice, i.e. sum(nvec) outputs (here sum([2, 2]) == 4).
        self.pi_out_features = int(np.sum(self.action_space.nvec))

        self.foo_encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_space["foo"].shape[0], out_features=16, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=self.concat_embedding_size, bias=True),
        )

        self.bar_encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_space["bar"].shape[0], out_features=16, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=self.concat_embedding_size, bias=True),
        )

        self.my_pi = nn.Sequential(
            nn.Linear(in_features=self.concat_embedding_size * 2, out_features=64, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=self.pi_out_features, bias=True),
        )

        self.my_vf = nn.Sequential(
            nn.Linear(in_features=self.concat_embedding_size * 2, out_features=64, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=1, bias=True),
        )

    @override(TorchRLModule)
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Default forward pass (used for inference and exploration)."""
        # Compute the basic 1D feature tensor (inputs to policy- and value-heads).
        _, logits = self._compute_embeddings_and_logits(batch)
        # Return features and logits as ACTION_DIST_INPUTS (categorical distribution).
        return {
            Columns.ACTION_DIST_INPUTS: logits,
        }

    @override(TorchRLModule)
    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Compute the shared embeddings and the action logits.
        embeddings, logits = self._compute_embeddings_and_logits(batch)
        # Return the logits as ACTION_DIST_INPUTS, plus the embeddings so that
        # compute_values() can reuse them during loss computation.
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: embeddings,
        }
      
    # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used
    # by value-based methods like PPO or IMPALA.
    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ) -> TensorType:
        # Embeddings not provided -> We need to compute them first.
        if embeddings is None:
            obs = batch[Columns.OBS]
            foo_embeddings = self.foo_encoder(obs["foo"])
            bar_embeddings = self.bar_encoder(obs["bar"])
            # Concatenate both outputs along the feature dim -> shape [B, 2 * concat_embedding_size].
            embeddings = torch.cat((foo_embeddings, bar_embeddings), dim=-1)
        
        return self.my_vf(embeddings).squeeze(-1)  # Squeeze out last dimension (single node value head).
    
    def _compute_embeddings_and_logits(self, batch):
        obs = batch[Columns.OBS]
        foo_obs = obs["foo"]
        bar_obs = obs["bar"]
        foo_embeddings = self.foo_encoder(foo_obs)
        bar_embeddings = self.bar_encoder(bar_obs)
        embeddings = torch.cat((foo_embeddings, bar_embeddings), dim=-1)  # Concatenate both outputs along feature dim
        logits = self.my_pi(embeddings)

        return embeddings, logits
    

class TestDictCatalog(PPOCatalog):
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        # Short-circuit the normal Catalog __init__ entirely to avoid the call to
        # _get_encoder_config(), which is not implemented for Dict spaces yet.
        self.observation_space = observation_space
        self.action_space = action_space

        self._action_dist_class_fn = functools.partial(
            self._get_dist_cls_from_action_space, action_space=self.action_space
        )

    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        **kwargs,
    ):
        if isinstance(observation_space, gym.spaces.Dict):
            # TODO here
            raise NotImplementedError
        else:
            return super()._get_encoder_config(observation_space, **kwargs)


def env_creator_dummy(config):
    return DummyDictEnv()

register_env("dummy_env", env_creator_dummy)

config = (
    PPOConfig()
    .environment(
        "dummy_env",
    )
    .env_runners(
        num_env_runners=6,
        num_envs_per_env_runner=1,
    )
    .learners(
        num_learners=1,
        num_gpus_per_learner=1,
    )
    .training(
        lr=0.0003,
        num_epochs=6,
        vf_loss_coeff=0.01,
    )
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=TestDictTorchPPORLModule,
            catalog_class=TestDictCatalog,
        ),
    )
)

if __name__ == '__main__':

    tuner = tune.Tuner(
        trainable="PPO",
        param_space=config,
        run_config=air.RunConfig(
            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, checkpoint_frequency=20),
            stop={"num_env_steps_sampled_lifetime": 200000},
        ),
    )

    ray.init()
    results = tuner.fit()
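
As a side note, if you just want to sanity-check the module without spinning up Tune, you should be able to build it standalone through an RLModuleSpec and run a single forward pass, along these lines (untested as written, reusing the imports from the example above):

# Standalone sanity check of the custom RLModule (no Tune, no Algorithm).
env = DummyDictEnv()
spec = RLModuleSpec(
    module_class=TestDictTorchPPORLModule,
    catalog_class=TestDictCatalog,
    observation_space=env.observation_space,
    action_space=env.action_space,
)
module = spec.build()

obs, _ = env.reset()
# Add a batch dimension to each obs entry: shapes (1,) -> (1, 1).
batch = {Columns.OBS: {k: torch.from_numpy(v).unsqueeze(0) for k, v in obs.items()}}
out = module.forward_inference(batch)
print(out[Columns.ACTION_DIST_INPUTS].shape)  # Expected: torch.Size([1, 4])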

We can see that it trains pretty well on my dummy env, so I think it works!

I will try to keep this updated if I manage to create a real EncoderConfig to avoid hardcoding the NN architecture.
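
For reference, the direction I have in mind for that is roughly the following (completely untested, loosely based on the custom-encoder examples in the RLlib repo; DictMLPEncoder / DictMLPEncoderConfig are just names I made up). The idea is that, with a working encoder config, the short-circuited __init__ and the hardcoded MLPs in the RLModule would no longer be needed and the default PPO module could be used:

from dataclasses import dataclass

from ray.rllib.core.models.base import Encoder, ENCODER_OUT
from ray.rllib.core.models.configs import ModelConfig
from ray.rllib.core.models.torch.base import TorchModel


class DictMLPEncoder(TorchModel, Encoder):
    """Encodes each entry of the Dict obs with its own MLP, then concatenates them."""

    def __init__(self, config):
        super().__init__(config)
        self.nets = nn.ModuleDict({
            key: nn.Sequential(
                nn.Linear(space.shape[0], 16),
                nn.ReLU(),
                nn.Linear(16, config.embedding_size),
            )
            for key, space in config.observation_space.spaces.items()
        })

    def _forward(self, inputs, **kwargs):
        obs = inputs[Columns.OBS]
        # Concatenate the per-key embeddings along the feature dim.
        embeddings = torch.cat(
            [net(obs[key]) for key, net in self.nets.items()], dim=-1
        )
        return {ENCODER_OUT: embeddings}


@dataclass
class DictMLPEncoderConfig(ModelConfig):
    observation_space: gym.spaces.Dict = None
    embedding_size: int = 32

    @property
    def output_dims(self):
        # One embedding per Dict entry, concatenated along the feature dim.
        return (len(self.observation_space.spaces) * self.embedding_size,)

    def build(self, framework: str = "torch"):
        assert framework == "torch", "Only torch is sketched here"
        return DictMLPEncoder(self)


# And in the catalog, _get_encoder_config() would become:
#
#     @classmethod
#     def _get_encoder_config(cls, observation_space, **kwargs):
#         if isinstance(observation_space, gym.spaces.Dict):
#             return DictMLPEncoderConfig(observation_space=observation_space)
#         return super()._get_encoder_config(observation_space, **kwargs)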

Best regards
