Here is a code template for a custom RLModule that processes a Dict observation space with a CNN for the "pixels" key (a 64x64x4 image) and an MLP for the "features" key (a 9-dim vector), concatenates the two embeddings, and feeds them into policy and value heads for a MultiDiscrete([3, 2, 2, 6]) action space. The template inherits from PPOTorchRLModule and implements ValueFunctionAPI, which PPO requires for value-function predictions, and it converts the channels-last observations to the channels-first layout that PyTorch's Conv2d expects (source, custom_cnn_rl_module.py).
```python
import torch
import torch.nn as nn

from ray.rllib.core.columns import Columns
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.utils.annotations import override


class MyCustomRLModule(PPOTorchRLModule, ValueFunctionAPI):
    @override(PPOTorchRLModule)
    def setup(self):
        # CNN for "pixels" (input: [B, 64, 64, 4], output: [B, cnn_out]).
        self.cnn = nn.Sequential(
            nn.Conv2d(4, 32, kernel_size=8, stride=4),  # [B, 32, 15, 15]
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),  # [B, 64, 6, 6]
            nn.ReLU(),
            nn.Flatten(),  # [B, 64 * 6 * 6]
        )
        self.cnn_out_dim = 64 * 6 * 6

        # MLP for "features" (input: [B, 9], output: [B, mlp_out]).
        self.mlp = nn.Sequential(
            nn.Linear(9, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.mlp_out_dim = 64

        # Concatenated embedding size.
        concat_dim = self.cnn_out_dim + self.mlp_out_dim

        # Policy head for MultiDiscrete([3, 2, 2, 6]):
        # 3 + 2 + 2 + 6 = 13 logits in total.
        self.policy_head = nn.Linear(concat_dim, sum([3, 2, 2, 6]))
        # Value head.
        self.value_head = nn.Linear(concat_dim, 1)

    def _extract_obs(self, batch):
        obs = batch[Columns.OBS]
        pixels = obs["pixels"].float()  # [B, 64, 64, 4]
        features = obs["features"].float()  # [B, 9]
        # Convert to channels-first for the CNN: [B, 4, 64, 64].
        pixels = pixels.permute(0, 3, 1, 2)
        return pixels, features

    @override(PPOTorchRLModule)
    def _forward_inference(self, batch, **kwargs):
        pixels, features = self._extract_obs(batch)
        cnn_out = self.cnn(pixels)
        mlp_out = self.mlp(features)
        concat = torch.cat([cnn_out, mlp_out], dim=-1)
        logits = self.policy_head(concat)
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.VF_PREDS: self.value_head(concat).squeeze(-1),
        }

    @override(PPOTorchRLModule)
    def _forward_exploration(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    @override(PPOTorchRLModule)
    def _forward_train(self, batch, **kwargs):
        return self._forward_inference(batch, **kwargs)

    @override(ValueFunctionAPI)
    def compute_values(self, batch, embeddings=None):
        pixels, features = self._extract_obs(batch)
        cnn_out = self.cnn(pixels)
        mlp_out = self.mlp(features)
        concat = torch.cat([cnn_out, mlp_out], dim=-1)
        return self.value_head(concat).squeeze(-1)
```
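If you change the image size or the conv settings, you can sanity-check the shape arithmetic annotated in setup() (64 → 15 → 6 with these kernel/stride values) with a quick standalone snippet mirroring the layers above:

```python
import torch
import torch.nn as nn

cnn = nn.Sequential(
    nn.Conv2d(4, 32, kernel_size=8, stride=4),   # (64 - 8) / 4 + 1 = 15
    nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2),  # (15 - 4) / 2 + 1 = 6
    nn.ReLU(),
    nn.Flatten(),
)
x = torch.zeros(2, 4, 64, 64)  # [B, C, H, W], i.e. after the permute in _extract_obs()
assert cnn(x).shape == (2, 64 * 6 * 6)  # matches cnn_out_dim = 2304
```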
Integrate the module with your PPOConfig via an RLModuleSpec (on older Ray versions the class is named SingleAgentRLModuleSpec):

```python
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

config.rl_module(
    rl_module_spec=RLModuleSpec(
        module_class=MyCustomRLModule,
        model_config={},
    )
)
```
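For reference, these are the spaces the template's hard-coded layer sizes assume; the value bounds below are illustrative assumptions, only the shapes matter to the module:

```python
import gymnasium as gym
import numpy as np

obs_space = gym.spaces.Dict({
    # CNN input; any Box with shape (64, 64, 4) works for this template.
    "pixels": gym.spaces.Box(0.0, 255.0, shape=(64, 64, 4), dtype=np.float32),
    # MLP input.
    "features": gym.spaces.Box(-np.inf, np.inf, shape=(9,), dtype=np.float32),
})
action_space = gym.spaces.MultiDiscrete([3, 2, 2, 6])  # policy head emits 3 + 2 + 2 + 6 = 13 logits
```

On Ray versions where the old API stack is still the default, also enable the new stack, e.g. config.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True).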