How to use PPO with Dict observation space (pixels + features) in Ray 2.48.0?

Here is a code template for a custom RLModule that processes a Dict observation space with a CNN for “pixels” and an MLP for “features”, concatenates the two embeddings, and feeds the result into separate policy and value heads for a MultiDiscrete([3, 2, 2, 6]) action space. The template inherits from PPOTorchRLModule and implements ValueFunctionAPI, which PPO on the new API stack requires; it also sets a TorchMultiCategorical action distribution to match the MultiDiscrete space and converts channel-last observations to channel-first for PyTorch’s Conv2d (source, custom_cnn_rl_module.py).

import torch
import torch.nn as nn
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.models.torch.torch_distributions import TorchMultiCategorical
from ray.rllib.utils.annotations import override

class MyCustomRLModule(PPOTorchRLModule, ValueFunctionAPI):
    @override(PPOTorchRLModule)
    def setup(self):
        # Action distribution for MultiDiscrete([3, 2, 2, 6]): a multi-
        # categorical whose flat logits are split as [3, 2, 2, 6]. Setting
        # self.action_dist_cls tells RLlib how to turn ACTION_DIST_INPUTS
        # into an action distribution.
        self.action_dist_cls = TorchMultiCategorical.get_partial_dist_cls(
            input_lens=[3, 2, 2, 6]
        )

        # CNN for "pixels" (input: [B, 64, 64, 4] channel-last, output: [B, cnn_out])
        self.cnn = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),  # [B, 32, 15, 15]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), # [B, 64, 6, 6]
            nn.ReLU(),
            nn.Flatten(),                               # [B, 64*6*6]
        )
        self.cnn_out_dim = 64 * 6 * 6

        # MLP for "features" (input: [B, 9], output: [B, mlp_out])
        self.mlp = nn.Sequential(
            nn.Linear(9, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.mlp_out_dim = 64

        # Concatenated embedding size
        concat_dim = self.cnn_out_dim + self.mlp_out_dim

        # Policy head for MultiDiscrete([3, 2, 2, 6])
        self.policy_head = nn.Linear(concat_dim, sum([3, 2, 2, 6]))

        # Value head
        self.value_head = nn.Linear(concat_dim, 1)

    def _extract_obs(self, batch):
        obs = batch[Columns.OBS]
        pixels = obs["pixels"].float()  # [B, 64, 64, 4]
        features = obs["features"].float()  # [B, 9]
        # If your env emits uint8 pixels in [0, 255], consider scaling here,
        # e.g. pixels = pixels / 255.0.
        # Convert to channel-first for PyTorch's Conv2d: [B, 4, 64, 64].
        pixels = pixels.permute(0, 3, 1, 2)
        return pixels, features

    @override(PPOTorchRLModule)
    def _forward_inference(self, batch, **kwargs):
        pixels, features = self._extract_obs(batch)
        cnn_out = self.cnn(pixels)
        mlp_out = self.mlp(features)
        concat = torch.cat([cnn_out, mlp_out], dim=-1)
        logits = self.policy_head(concat)
        # Only the action-distribution inputs are needed here; on the new API
        # stack, PPO obtains value predictions through compute_values() below.
        return {Columns.ACTION_DIST_INPUTS: logits}

    @override(PPOTorchRLModule)
    def _forward_exploration(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    @override(PPOTorchRLModule)
    def _forward_train(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    @override(ValueFunctionAPI)
    def compute_values(self, batch, embeddings=None):
        # `embeddings` is only non-None if a forward pass returned
        # Columns.EMBEDDINGS; this template recomputes from raw observations.
        pixels, features = self._extract_obs(batch)
        cnn_out = self.cnn(pixels)
        mlp_out = self.mlp(features)
        concat = torch.cat([cnn_out, mlp_out], dim=-1)
        return self.value_head(concat).squeeze(-1)
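
To verify shapes before training, you can build the module standalone through an RLModuleSpec and push a dummy batch through it. This is a minimal sketch assuming the spaces from the question (64x64x4 pixels, 9 features, MultiDiscrete([3, 2, 2, 6])); adjust the spaces to your env:

import gymnasium as gym
import numpy as np
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

obs_space = gym.spaces.Dict({
    "pixels": gym.spaces.Box(0.0, 255.0, (64, 64, 4), np.float32),
    "features": gym.spaces.Box(-np.inf, np.inf, (9,), np.float32),
})
act_space = gym.spaces.MultiDiscrete([3, 2, 2, 6])

module = RLModuleSpec(
    module_class=MyCustomRLModule,
    observation_space=obs_space,
    action_space=act_space,
).build()

dummy_batch = {
    Columns.OBS: {
        "pixels": torch.rand(2, 64, 64, 4),
        "features": torch.rand(2, 9),
    }
}
out = module.forward_inference(dummy_batch)
print(out[Columns.ACTION_DIST_INPUTS].shape)  # torch.Size([2, 13]) == sum([3, 2, 2, 6])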

# Integration with PPOConfig (new API stack, the default in Ray 2.48.0):
# from ray.rllib.core.rl_module.rl_module import RLModuleSpec
#
# config.rl_module(
#     rl_module_spec=RLModuleSpec(module_class=MyCustomRLModule),
# )
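
For completeness, here is a fuller config sketch. Note that "my_dict_obs_env" below is a hypothetical placeholder for your own env with the matching Dict observation space, registered via ray.tune.registry.register_env:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

config = (
    PPOConfig()
    # Hypothetical env ID; register your Dict-obs env under this name first.
    .environment("my_dict_obs_env")
    .rl_module(rl_module_spec=RLModuleSpec(module_class=MyCustomRLModule))
)
algo = config.build_algo()
print(algo.train())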
