How to use PPO with Dict observation space (pixels + features) in Ray 2.48.0?

Context

I’m training a PPO agent in RLlib (Ray 2.48.0) with a custom Gymnasium environment that returns a Dict observation space containing both pixels and vector features:

import gymnasium as gym
import numpy as np

obs_space = gym.spaces.Dict({
    "pixels": gym.spaces.Box(0.0, 1.0, (84, 84, 4), dtype=np.float32),
    "features": gym.spaces.Box(-1.0, 1.0, (9,), dtype=np.float32),
})

The step() returns:

{
    "pixels": np.zeros((84, 84, 4), np.float32),
    "features": np.zeros(9, np.float32),
}

Problem
When running PPO with this env:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env

class DummyEnv(gym.Env):
    def __init__(self, cfg=None):
        self.observation_space = obs_space
        self.action_space = gym.spaces.Discrete(4)
    def reset(self, *, seed=None, options=None):
        return { "pixels": np.zeros((84,84,4), np.float32),
                 "features": np.zeros(9, np.float32) }, {}
    def step(self, action):
        return self.reset()[0], 0.0, False, False, {}

register_env("dummy", lambda cfg: DummyEnv())

cfg = (PPOConfig()
       .environment("dummy")
       .framework("torch"))

algo = cfg.build()

I get:

ValueError: No default encoder config for obs space=Dict('features': Box(-1.0, 1.0, (9,), float32),
'pixels': Box(0.0, 1.0, (84, 84, 4), float32)), lstm=False found.

Question

  • What is the recommended way in Ray 2.48.0 to handle such Dict spaces (CNN for "pixels" and MLP for "features", then concatenate)?

  • Do I need to manually define a custom RLModuleSpec / Catalog for this, or is there a built-in default?

  • If a manual config is required, could you provide a minimal example (Torch backend, PPO)?

System Info

  • Ray 2.48.0

  • Python 3.10

Workaround tested
Flattening the Dict works, but then "pixels" are treated as a flat vector and CNN processing is lost. Ideally, I’d like RLlib to auto-create a CNN branch for "pixels" and an MLP branch for "features".
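
To make that concrete, here is a minimal sketch of the flatten workaround using Gymnasium's built-in FlattenObservation wrapper on the DummyEnv above; it shows why the CNN structure is lost (the pixel layout disappears into one long vector):

import gymnasium as gym

# Wrap the Dict-obs env so every observation becomes a single flat vector.
flat_env = gym.wrappers.FlattenObservation(DummyEnv())
print(flat_env.observation_space.shape)  # (28233,) -> 84*84*4 + 9, spatial structure gone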

The action masking example should give you the information you need on handling dictionary observation spaces in the way the system expects.

You might also want to look at some examples featuring custom encoder architectures. This one should cover the bases well enough.

Thank you for the pointers to action_masking_rlm.py and tiny_atari_cnn_rlm.py!

I have a follow-up question to clarify my current setup:

My observation space (defined in the Gymnasium env):

self.observation_space = spaces.Dict({
    "pixels": spaces.Box(low=0.0, high=1.0, shape=(84, 84, 4), dtype=np.float32),
    "features": spaces.Box(low=-1.0, high=1.0, shape=(9,), dtype=np.float32),
})

My training config (Ray 2.48.0, PPO):

config = (
    PPOConfig()
    .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
    .environment(env="SuperTuxEnv")
    # ... no custom rl_module_spec or model config
)

My concern:
With the new API stack enabled and no explicit RLModule or encoder configuration, I’m unsure how RLlib handles my Dict observation space:

  1. Best case: RLlib’s Catalog automatically creates a CNN encoder for pixels (84x84x4) and an MLP encoder for features (9,), then concatenates them before the policy/value heads.

  2. Worst case: RLlib flattens the entire Dict into a 1D vector (84×84×4 + 9 = 28,233 values) and feeds it into a single MLP - which would be very inefficient for image data.

My question:

  • Which behavior applies with the default PPOConfig + new API stack + Dict observation space?
  • If it’s the “flatten” case: Do I need to implement a custom RLModule (similar to action_masking_rlm.py) that manually splits the Dict and routes pixels through a CNN and features through an MLP?

I reviewed the examples you linked:

  • action_masking_rlm.py shows how to extract parts of a Dict obs in _preprocess_batch()
  • tiny_atari_cnn_rlm.py shows a custom CNN architecture for pixel observations

Would combining these patterns (custom RLModule with CNN for pixels + MLP for features) be the recommended approach for my use case?

Thanks!

With the new API stack enabled and no custom RLModule or model config, RLlib does not automatically create separate CNN/MLP branches for Dict observation spaces. By default, RLlib will attempt to flatten the Dict into a 1D vector and process it with a single MLP, which is inefficient for image data and not what you want for your setup. This is confirmed by user reports and error messages in recent discussions and issues, e.g., “No default encoder config for obs space=Dict(…)” and the need for custom handling of Dict spaces in the new stack (source, source, source).

To properly process your Dict obs (pixels via CNN, features via MLP), you should implement a custom RLModule that splits the Dict and routes each part through the appropriate encoder, then concatenates the results before the policy/value heads. This is the recommended approach and is demonstrated in the examples you referenced (action_masking_rlm.py and tiny_atari_cnn_rlm.py). Would you like a step-by-step breakdown or a code template for this pattern?

Hint: Mention @RunLLM in the post for followups.

@RunLLM Thanks for confirming! A code template would be very helpful.

My Environment:

  • Ray: 2.52.1
  • PyTorch: 2.9.1+cu130
  • OS: Windows 10/11
  • Python: 3.10

Observation Space (from my Gymnasium env):

self.observation_space = spaces.Dict({
    # Stacked grayscale frames, channel-last: (H, W, K)
    "pixels": spaces.Box(
        low=0.0, high=1.0, 
        shape=(84, 84, 4),  # 84x84, 4 stacked frames
        dtype=np.float32
    ),
    # Normalized game state features
    "features": spaces.Box(
        low=-1.0, high=1.0, 
        shape=(9,),  # x, y, dx, dy, coins, etc.
        dtype=np.float32
    ),
})

Action Space:

self.action_space = spaces.MultiDiscrete([3, 2, 2, 6])
# [direction(left/idle/right), duck(no/yes), run(no/yes), jump_bucket(6 levels)]
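
Side note on that action space: the usual convention (and, as far as I know, what RLlib's multi-categorical action distribution expects) is a single flat logits vector of length sum(nvec) = 3 + 2 + 2 + 6 = 13 that gets split into one categorical per sub-action. A tiny plain-PyTorch sketch of that split, independent of RLlib:

import torch

nvec = [3, 2, 2, 6]                 # sub-action sizes of MultiDiscrete([3, 2, 2, 6])
logits = torch.randn(1, sum(nvec))  # [B, 13] flat output of a single policy head

# Split the flat logits into per-sub-action groups and sample one action per group.
sub_logits = torch.split(logits, nvec, dim=-1)
actions = [torch.distributions.Categorical(logits=l).sample() for l in sub_logits]
print(torch.stack(actions, dim=-1))  # e.g. tensor([[2, 0, 1, 4]])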

Current PPOConfig (simplified):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .environment("supertux")  # registered via register_env()
    .framework("torch")
    .resources(num_gpus=0)
)

# Training params
config.rollout_fragment_length = 2048
config.train_batch_size = 4096
config.sgd_minibatch_size = 256
config.num_sgd_iter = 4
config.lr = 3e-4
config.gamma = 0.99
config.entropy_coeff = 0.05
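
Aside: on the new API stack these settings are normally passed through the fluent .training() / .env_runners() calls, and some names were renamed in recent releases (sgd_minibatch_size -> minibatch_size, num_sgd_iter -> num_epochs). A sketch of the equivalent calls - please double-check the exact kwarg names against the installed Ray version:

# Equivalent fluent-API form (kwarg names may differ slightly between Ray versions).
config.training(
    train_batch_size=4096,
    minibatch_size=256,   # formerly sgd_minibatch_size
    num_epochs=4,         # formerly num_sgd_iter
    lr=3e-4,
    gamma=0.99,
    entropy_coeff=0.05,
)
config.env_runners(rollout_fragment_length=2048)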

What I need in the template:

  1. Custom RLModule class (preferably inheriting from PPOTorchRLModule)

    • setup() method defining:
      • CNN encoder for pixels (84x84x4, channel-last format)
      • MLP encoder for features (9,)
      • Concatenation layer
      • Policy head for MultiDiscrete([3, 2, 2, 6]) action space
      • Value head
  2. Forward methods:

    • _forward_inference()
    • _forward_exploration()
    • _forward_train()
    • Properly extracting batch[Columns.OBS]["pixels"] and batch[Columns.OBS]["features"]
  3. Integration with PPOConfig:

    config.rl_module_spec = RLModuleSpec(
        module_class=MyCustomRLModule,
        model_config={"...": "..."},
    )
    
  4. Note on channel format:
    My pixels are channel-last (84, 84, 4) - the template should handle permute(0, 3, 1, 2) for PyTorch Conv2d if needed (a quick shape-check sketch follows after this list).

  5. ValueFunctionAPI compliance (for PPO’s value function baseline)
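
For items 1 and 4, here is a quick standalone shape check for channel-last 84x84x4 frames (the two conv layers are just example sizes, not anything RLlib ships by default); it confirms the permute and the flattened CNN output dimension:

import torch
import torch.nn as nn

# Example two-layer CNN for 84x84x4 stacked frames (channel-last in the env).
cnn = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4),   # (84-8)//4 + 1 = 20 -> [B, 32, 20, 20]
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2),  # (20-4)//2 + 1 = 9  -> [B, 64, 9, 9]
    nn.ReLU(),
    nn.Flatten(),                                # [B, 64*9*9] = [B, 5184]
)
pixels = torch.zeros(1, 84, 84, 4)               # channel-last, as the env returns it
out = cnn(pixels.permute(0, 3, 1, 2))            # NHWC -> NCHW for torch Conv2d
print(out.shape)                                 # torch.Size([1, 5184])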

Thanks for your help!

@RunLLM

Friendly reminder - I’m still waiting for the code template from my previous message.

In the meantime, I have two additional questions:


1. Does MARWIL + PPO require the same custom RLModule approach?

I’m planning to use MARWIL for pre-training (offline imitation learning from human demos) followed by PPO fine-tuning.

Since MARWIL uses the same RLModuleSpec architecture as PPO (MARWIL docs), I assume:

  • The Dict observation space problem applies equally to MARWIL
  • I would need to use the same custom RLModule (CNN for pixels + MLP for features) for both MARWIL and PPO
  • The RLModule checkpoint from MARWIL can be loaded into PPO for fine-tuning

Is this correct?


2. Does DreamerV3 handle Dict observation spaces natively?

According to the DreamerV3 docs:

“It can also be used in different environment types, including those with image- or vector based observations.”

The Atari 100k and DM Control Vision examples show DreamerV3 working with pixel observations.

My question:

  • Does DreamerV3’s built-in World Model automatically handle Dict observation spaces like {"pixels": Box(84,84,4), "features": Box(9,)}?
  • Or do I need a custom encoder configuration similar to the PPO case?

I noticed DreamerV3 has its own DreamerV3Config.model() method - is there a way to configure separate encoders for pixels vs. features there?


Environment reminder:

  • Ray: 2.52.1
  • PyTorch: 2.9.1+cu130
  • Python: 3.10
  • Observation Space: Dict({"pixels": Box(64,64,4), "features": Box(9,)})
  • Action Space: MultiDiscrete([3, 2, 2, 6])

Thanks!

  1. Yes, both MARWIL and PPO require the same custom RLModule approach for Dict observation spaces. MARWIL does not natively handle Dict spaces with separate encoders, so you must implement a custom RLModule (CNN for “pixels”, MLP for “features”) and use it for both algorithms. The RLModule checkpoint from MARWIL can be loaded into PPO for fine-tuning, as both use the same RLModuleSpec architecture (source, source, source).

  2. DreamerV3 does not natively handle Dict observation spaces with automatic per-key encoder routing. Its built-in World Model expects a single image or vector input, not a Dict, so you would need to implement a custom encoder or modify the model config to handle Dicts (e.g., by concatenating or custom processing). There is currently no built-in way in DreamerV3Config.model() to specify separate encoders for Dict keys (source).
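
For the checkpoint hand-off in point 1, one simple option is to move the module weights over as a plain PyTorch state dict. This is only a sketch: marwil_algo and ppo_algo stand for the already-built Algorithm instances, it assumes both algorithms are configured with the exact same custom RLModule class and model_config, and it assumes Algorithm.get_module() is available in your Ray version; how the weights then get synced to the Learner workers depends on the release you are on.

import torch

# After MARWIL pre-training: the custom RLModule is a regular torch nn.Module.
marwil_module = marwil_algo.get_module()
torch.save(marwil_module.state_dict(), "marwil_pretrained_module.pt")

# Before PPO fine-tuning: build PPO with the same custom RLModuleSpec, then load.
ppo_module = ppo_algo.get_module()
ppo_module.load_state_dict(torch.load("marwil_pretrained_module.pt"))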

Would you like a code template for the custom RLModule pattern?


@RunLLM Yes, please provide the code template for the custom RLModule!

To recap my requirements:

Observation Space:

spaces.Dict({
    "pixels": spaces.Box(0.0, 1.0, shape=(64, 64, 4), dtype=np.float32),
    "features": spaces.Box(-1.0, 1.0, shape=(9,), dtype=np.float32),
})

What I need in the template:

  1. Custom RLModule class (preferably inheriting from PPOTorchRLModule)

    • setup() method defining:
      • CNN encoder for pixels (64x64x4, channel-last format)
      • MLP encoder for features (9,)
      • Concatenation layer
      • Policy head for MultiDiscrete([3, 2, 2, 6]) action space
      • Value head
  2. Forward methods:

    • _forward_inference()
    • _forward_exploration()
    • _forward_train()
    • Properly extracting batch[Columns.OBS]["pixels"] and batch[Columns.OBS]["features"]
  3. Integration with PPOConfig:

    config.rl_module_spec = RLModuleSpec(
        module_class=MyCustomRLModule,
        model_config={"...": "..."},
    )
    
  4. Note on channel format:
    My pixels are channel-last (64, 64, 4) - the template should handle permute(0, 3, 1, 2) for PyTorch Conv2d if needed.

  5. ValueFunctionAPI compliance (for PPO’s value function baseline)

Thanks for your help!

Here is a code template for a custom RLModule that processes a Dict observation space with a CNN for “pixels” and an MLP for “features”, concatenates their embeddings, and outputs policy/value heads for a MultiDiscrete action space. This template inherits from PPOTorchRLModule and implements ValueFunctionAPI, as required for PPO. The template handles channel-last to channel-first conversion for PyTorch Conv2d (source, custom_cnn_rl_module.py).

import torch
import torch.nn as nn
from ray.rllib.core.columns import Columns
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.utils.annotations import override

class MyCustomRLModule(PPOTorchRLModule, ValueFunctionAPI):
    @override(PPOTorchRLModule)
    def setup(self):
        # CNN for "pixels" (input: [B, 64, 64, 4], output: [B, cnn_out])
        self.cnn = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),  # [B, 32, 15, 15]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # [B, 64, 6, 6]
            nn.ReLU(),
            nn.Flatten(),                               # [B, 64*6*6]
        )
        self.cnn_out_dim = 64 * 6 * 6

        # MLP for "features" (input: [B, 9], output: [B, mlp_out])
        self.mlp = nn.Sequential(
            nn.Linear(9, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.mlp_out_dim = 64

        # Concatenated embedding size
        concat_dim = self.cnn_out_dim + self.mlp_out_dim

        # Policy head for MultiDiscrete([3, 2, 2, 6])
        self.policy_head = nn.Linear(concat_dim, sum([3, 2, 2, 6]))

        # Value head
        self.value_head = nn.Linear(concat_dim, 1)

    def _extract_obs(self, batch):
        obs = batch[Columns.OBS]
        pixels = obs["pixels"].float()  # [B, 64, 64, 4]
        features = obs["features"].float()  # [B, 9]
        # Convert to channel-first for CNN: [B, 4, 64, 64]
        pixels = pixels.permute(0, 3, 1, 2)
        return pixels, features

    @override(PPOTorchRLModule)
    def _forward_inference(self, batch, **kwargs):
        pixels, features = self._extract_obs(batch)
        cnn_out = self.cnn(pixels)
        mlp_out = self.mlp(features)
        concat = torch.cat([cnn_out, mlp_out], dim=-1)
        logits = self.policy_head(concat)
        return {Columns.ACTION_DIST_INPUTS: logits, Columns.VF_PREDS: self.value_head(concat).squeeze(-1)}

    @override(PPOTorchRLModule)
    def _forward_exploration(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    @override(PPOTorchRLModule)
    def _forward_train(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    @override(ValueFunctionAPI)
    def compute_values(self, batch, embeddings=None):
        # `embeddings` (if passed by the learner) are ignored here; both encoder
        # branches are simply recomputed from the raw observations.
        pixels, features = self._extract_obs(batch)
        cnn_out = self.cnn(pixels)
        mlp_out = self.mlp(features)
        concat = torch.cat([cnn_out, mlp_out], dim=-1)
        return self.value_head(concat).squeeze(-1)

# Integration with PPOConfig:
# config.rl_module(
#     rl_module_spec=RLModuleSpec(
#         module_class=MyCustomRLModule,
#         model_config={},
#     )
# )
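
And a minimal (untested) integration sketch tying the module into the config from earlier in the thread; the RLModuleSpec import path below is the one used in recent Ray releases, so please verify it against 2.52.1:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

config = (
    PPOConfig()
    .environment("supertux")   # registered earlier via register_env()
    .framework("torch")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=MyCustomRLModule,
            model_config={},   # pass CNN/MLP sizes here if you parametrize setup()
        )
    )
)

algo = config.build()
print(algo.train())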
