OK, so I think I found a decent solution (probably not optimal, but I'm not familiar with the inner workings of RLlib).
The core of the issue is that when the catalog is initialized, it calls the method _determine_components_hook(), which in turn calls _get_encoder_config(), and Dict observation spaces are not a handled case there, so it raises an exception.
The cleanest fix would be to override _get_encoder_config() to return a custom EncoderConfig that does the job. That is my next step, but I haven't had the time to do it yet.
So what I did instead was basically short-circuit the Catalog by overriding its __init__ and keeping only the bare minimum, and then define a custom RLModule that is compliant with the Dict observation space defined in the env (here it just passes each part of the obs through two distinct MLPs, then concatenates their embeddings).
Working example
import functools
import ray
from ray import tune, air
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.tune.registry import register_env
import gymnasium as gym
import numpy as np
from typing import Any, Dict, Optional
from ray.rllib.utils.typing import TensorType
from ray.rllib.core.columns import Columns
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
torch, nn = try_import_torch()
class DummyDictEnv(gym.Env):
    """
    At each step, the observation is a dict of two obs that are between 0 and 1.
    The agent has to choose a discrete action (0 or 1) for each obs, and is rewarded
    if it chooses the closest number on both "axes".
    """

    def __init__(self, duration=32, seed=None):
        self.duration = duration
        self.current_step = 0
        # Basically MultiBinary(2), but that is unsupported by RLlib out of the box
        self.action_space = gym.spaces.MultiDiscrete([2, 2], seed=seed)
        self.observation_space = gym.spaces.Dict({
            "foo": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
            "bar": gym.spaces.Box(low=0, high=1, shape=(1,), seed=seed),
        })
        super().__init__()

    def reset(self, *, seed=None, options=None):
        self.last_obs = self.observation_space.sample()
        self.current_step = 0
        return self.last_obs, {}

    def step(self, action):
        reward = 0
        good_foo_action = round(self.last_obs["foo"][0])
        if action[0] == good_foo_action:
            reward += 1
        else:
            reward -= 1
        good_bar_action = round(self.last_obs["bar"][0])
        if action[1] == good_bar_action:
            reward += 1
        else:
            reward -= 1
        self.last_obs = self.observation_space.sample()
        done = False
        if self.current_step >= self.duration:
            done = True
        self.current_step += 1
        return self.last_obs, reward, done, False, {}
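# Custom RLModule handling the Dict observation space: one small MLP encoder per
# sub-observation ("foo" and "bar"), whose embeddings are concatenated and then fed
# into separate pi and vf heads.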
class TestDictTorchPPORLModule(TorchRLModule, ValueFunctionAPI):
    @override(TorchRLModule)
    def setup(self):
        assert isinstance(self.action_space, gym.spaces.MultiDiscrete)
        assert isinstance(self.observation_space, gym.spaces.Dict)
        self.concat_embedding_size = 32
        # The MultiCategorical action distribution expects sum(nvec) logits
        # (sum([2, 2]) = 4 here).
        self.pi_out_features = int(np.sum(self.action_space.nvec))
        self.foo_encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_space["foo"].shape[0], out_features=16, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=self.concat_embedding_size, bias=True),
        )
        self.bar_encoder = nn.Sequential(
            nn.Linear(in_features=self.observation_space["bar"].shape[0], out_features=16, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=16, out_features=self.concat_embedding_size, bias=True),
        )
        self.my_pi = nn.Sequential(
            nn.Linear(in_features=self.concat_embedding_size * 2, out_features=64, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=self.pi_out_features, bias=True),
        )
        self.my_vf = nn.Sequential(
            nn.Linear(in_features=self.concat_embedding_size * 2, out_features=64, bias=True),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=1, bias=True),
        )
    @override(TorchRLModule)
    def _forward(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Default forward pass (used for inference and exploration)."""
        # Compute the basic 1D feature tensor (inputs to the policy- and value-heads).
        _, logits = self._compute_embeddings_and_logits(batch)
        # Return the logits as ACTION_DIST_INPUTS (multi-categorical distribution).
        return {
            Columns.ACTION_DIST_INPUTS: logits,
        }

    @override(TorchRLModule)
    def _forward_train(self, batch: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        # Compute the basic 1D feature tensor (inputs to the policy- and value-heads).
        embeddings, logits = self._compute_embeddings_and_logits(batch)
        # Return the logits as ACTION_DIST_INPUTS and the embeddings so that
        # compute_values() can reuse them during training.
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.EMBEDDINGS: embeddings,
        }
    # We implement this RLModule as a ValueFunctionAPI RLModule, so it can be used
    # by value-based methods like PPO or IMPALA.
    @override(ValueFunctionAPI)
    def compute_values(
        self,
        batch: Dict[str, Any],
        embeddings: Optional[Any] = None,
    ) -> TensorType:
        # Embeddings not provided -> we need to compute them first.
        if embeddings is None:
            obs = batch[Columns.OBS]
            foo_obs = obs["foo"]
            bar_obs = obs["bar"]
            foo_embeddings = self.foo_encoder(foo_obs)
            bar_embeddings = self.bar_encoder(bar_obs)
            embeddings = torch.cat((foo_embeddings, bar_embeddings), dim=-1)  # Concatenate both outputs along the feature dim
            # Squeeze possible trailing singleton dims (a no-op for the (B, 64) embeddings here).
            embeddings = torch.squeeze(embeddings, dim=[-1, -2])
        return self.my_vf(embeddings).squeeze(-1)  # Squeeze out last dimension (single-node value head).

    def _compute_embeddings_and_logits(self, batch):
        obs = batch[Columns.OBS]
        foo_obs = obs["foo"]
        bar_obs = obs["bar"]
        foo_embeddings = self.foo_encoder(foo_obs)
        bar_embeddings = self.bar_encoder(bar_obs)
        embeddings = torch.cat((foo_embeddings, bar_embeddings), dim=-1)  # Concatenate both outputs along the feature dim
        logits = self.my_pi(embeddings)
        return embeddings, logits
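# Minimal Catalog that only keeps what the RLModule above still needs from it
# (the action-distribution resolution) and skips the default encoder construction,
# which would fail on the Dict observation space.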
class TestDictCatalog(PPOCatalog):
    def __init__(
        self,
        observation_space: gym.Space,
        action_space: gym.Space,
        model_config_dict: dict,
    ):
        # Short-circuit the normal Catalog __init__ completely to avoid the call to
        # _get_encoder_config(), which is not implemented for Dict spaces yet.
        self.observation_space = observation_space
        self.action_space = action_space
        self._action_dist_class_fn = functools.partial(
            self._get_dist_cls_from_action_space, action_space=self.action_space
        )

    @classmethod
    def _get_encoder_config(
        cls,
        observation_space: gym.Space,
        **kwargs,
    ):
        if isinstance(observation_space, gym.spaces.Dict):
            # TODO: return a custom EncoderConfig for Dict spaces here.
            raise NotImplementedError
        else:
            return super()._get_encoder_config(observation_space, **kwargs)
def env_creator_dummy(config):
    return DummyDictEnv()

register_env("dummy_env", env_creator_dummy)

config = (
    PPOConfig()
    .environment(
        "dummy_env",
    )
    .env_runners(
        num_env_runners=6,
        num_envs_per_env_runner=1,
    )
    .learners(
        num_learners=1,
        num_gpus_per_learner=1,
    )
    .training(
        lr=0.0003,
        num_epochs=6,
        vf_loss_coeff=0.01,
    )
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=TestDictTorchPPORLModule,
            catalog_class=TestDictCatalog,
        ),
    )
)
if __name__ == '__main__':
    tuner = tune.Tuner(
        trainable="PPO",
        param_space=config,
        run_config=air.RunConfig(
            checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True, checkpoint_frequency=20),
            stop={"num_env_steps_sampled_lifetime": 200000},
        ),
    )
    context = ray.init()
    results = tuner.fit()
We can see here that it trains pretty well on my dummy env, so I think it works!
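As a side note, if you just want to smoke-test the module locally without going through Tune, a quick sketch like the one below should also do the job (I haven't run this exact snippet; on newer Ray versions config.build() is being renamed to config.build_algo(), and the exact metric keys depend on your RLlib version):
algo = config.build()  # or config.build_algo() on newer Ray versions
for i in range(3):
    result = algo.train()
    # The metric layout below assumes the new API stack; adjust the keys to your version.
    print(i, result.get("env_runners", {}).get("episode_return_mean"))
algo.stop()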
I will try to keep this updated if I manage to create a real EncoderConfig to avoid hardcoding the NN architecture.
Best regards