1. Severity of the issue: (select one)
High: Completely blocks me.
2. Environment:
- Ray version: 2.47.1
- Python version: 3.11.7
- OS: Windows 10
3. What happened vs. what you expected:
Hello everyone,
I’m trying to implement a multi-agent learning setup with action masking. I previously had this working on the old RLlib API stack, but I’m now upgrading to the new API stack and having trouble understanding how everything fits together.
Based on my understanding (which may be wrong), I created a custom model that applies action masking and processes the observation in two separate branches:
- A CNN to process the image part of the observation
- A standard MLP to process the remaining vector-based observation
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.typing import TensorType
class MaskedModel(nn.Module):
    def __init__(self, num_outputs):
        super().__init__()
        # CNN encoder for truncated_image (assumes [3, 32, 32])
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # -> (32, 16, 16)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # -> (64, 8, 8)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # -> (128, 4, 4)
            nn.ReLU(),
            nn.Flatten(),  # -> 128 * 4 * 4 = 2048
        )
        self.image_out_size = 2048
        # Fully connected encoder for vector input
        self.vector_input_size = (
            100 + 1 + 1  # softmax + truncated + prev_agent_action
        )
        self.vector_encoder = nn.Sequential(
            nn.Linear(self.vector_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )
        self.fusion_size = self.image_out_size + 64
        # Maps the concatenated vector to action logits.
        self.policy_head = nn.Linear(self.fusion_size, num_outputs)
        # Maps it to a scalar value (for PPO critic).
        self._value_branch = nn.Linear(self.fusion_size, 1)

    def forward(self, fused_features):
        """Forward pass that returns both logits and value."""
        logits = self.policy_head(fused_features)
        value = self._value_branch(fused_features).squeeze(-1)
        return logits, value
class MaskedPPOModule(TorchRLModule):
    def setup(self):
        # Use the action space to infer the number of output nodes.
        self.model = MaskedModel(
            num_outputs=self.action_space.n,
        )
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _forward(self, batch, **kwargs):
        obs = batch["obs"]
        action_mask = obs["action_mask"].float().to(self.device)
        fused = self.concat(obs)
        logits, value = self.model(fused)
        # Apply action mask
        inf_mask = torch.clamp(torch.log(action_mask), min=-1e10)
        masked_logits = logits + inf_mask
        # see https://discuss.ray.io/t/keyerror-advantages-when-training-ppo-with-custom-model-in-rllib/21876/4
        return {
            Columns.ACTION_DIST_INPUTS: masked_logits,
            Columns.VF_PREDS: value,
            # These will be filled in by the PPO algorithm during training
            Columns.ADVANTAGES: torch.zeros_like(value),
            Columns.VALUE_TARGETS: torch.zeros_like(value),
        }

    def _forward_train(self, batch, **kwargs):
        obs = batch["obs"]
        action_mask = obs["action_mask"].float().to(self.device)
        fused = self.concat(obs)
        logits, value = self.model(fused)
        # Apply action mask
        inf_mask = torch.clamp(torch.log(action_mask), min=-1e10)
        masked_logits = logits + inf_mask
        outs = super()._forward_train(batch, **kwargs)
        outs[Columns.ACTION_DIST_INPUTS] = masked_logits
        return outs

    # Get the joint feature vector.
    def concat(self, obs):
        # Ensure all tensors are on the correct self.device
        image = obs["truncated_image"].float().to(self.device) / 255.0
        softmax = obs["softmax"].float().to(self.device)
        truncated = obs["truncated"].float().unsqueeze(-1).to(self.device)
        prev_action = obs["prev_agent_action"].unsqueeze(-1).float().to(self.device)
        # Forward pass through CNN and vector encoder
        img_out = self.model.cnn(image)
        vector_input = torch.cat([softmax, truncated, prev_action], dim=-1)
        vector_out = self.model.vector_encoder(vector_input)
        # Concatenate visual and vector features
        return torch.cat([img_out, vector_out], dim=-1)

    def compute_values(
        self, batch: Dict[str, Any], embeddings: Optional[Any] = None
    ) -> TensorType:
        """Compute value estimates for the critic head."""
        obs = batch[Columns.OBS]
        fused = self.concat(obs)
        return self.model._value_branch(fused).squeeze(-1)
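The shape comments in the CNN above can be double-checked without PyTorch, using the standard convolution output-size formula floor((n + 2p - k) / s) + 1. The following is only a sketch: `conv2d_out` is a hypothetical helper introduced for illustration, and it assumes a square 32x32 input and dilation 1, as stated in the model's own comment.

```python
def conv2d_out(size: int, kernel: int, stride: int, padding: int) -> int:
    """Output length along one spatial axis of a 2D convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel_size, stride, padding, out_channels) for the three Conv2d layers above.
layers = [(4, 2, 1, 32), (4, 2, 1, 64), (3, 2, 1, 128)]

size = 32  # assumption: truncated_image has shape [3, 32, 32]
sizes = []
for kernel, stride, padding, _channels in layers:
    size = conv2d_out(size, kernel, stride, padding)
    sizes.append(size)

# Flattened feature count after nn.Flatten(): channels * height * width.
flat_features = layers[-1][3] * size * size
print(sizes, flat_features)
```

Under these assumptions the spatial sizes come out as 16, 8 and 4, and the flattened size as 2048, which agrees with `image_out_size`. If the environment delivers the image at a different resolution, the flattened size changes and `image_out_size` would have to change with it.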
It took quite some effort to figure out what the _forward() method is supposed to return, so I’m not entirely sure whether I’m missing something or overwriting an important part of the pipeline.
My main problem now is that PPO doesn’t seem to learn anything: the output still appears completely random. I’ve tried tweaking many things, but I can no longer tell what is causing the issue.
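One piece that can be checked in isolation is the mask arithmetic, i.e. adding clamp(log(mask), min=-1e10) to the logits before the softmax. The sketch below reproduces that step in plain Python rather than PyTorch so it runs anywhere. `masked_softmax` is a hypothetical helper (not part of RLlib or PyTorch), and the logits and mask values are made up for illustration.

```python
import math


def masked_softmax(logits, action_mask, floor=-1e10):
    """Softmax over logits after adding clamp(log(mask), min=floor) to each one."""
    # log(1) = 0 leaves allowed logits unchanged. log(0) would be -inf, which the
    # clamp turns into `floor`; math.log(0) raises, so that case is handled directly.
    penalties = [floor if m == 0 else max(math.log(m), floor) for m in action_mask]
    masked = [logit + penalty for logit, penalty in zip(logits, penalties)]
    # Subtract the maximum for numerical stability before exponentiating.
    peak = max(masked)
    exps = [math.exp(x - peak) for x in masked]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, -1.0, 0.5, 3.0]   # made-up policy outputs
action_mask = [1, 0, 1, 0]       # actions 1 and 3 are not allowed
probs = masked_softmax(logits, action_mask)
print(probs)
```

In this sketch the masked actions end up with zero probability even though action 3 has the largest raw logit. If the real module behaves the same way on a sampled batch, the masking formula itself is probably not what keeps PPO from learning, and the remaining suspects would be elsewhere in the pipeline.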
For reference, I have the same environment working in SB3 (Stable-Baselines3), and it trains perfectly there. I wanted to move it to RLlib in order to add a second agent. However, even when the second agent only has one available action, the setup doesn’t work.
That’s why I wanted to ask: does this custom model approach look reasonable, or am I fundamentally misunderstanding something?
I’d really appreciate any help or guidance you can offer. Thanks in advance!
BR
Roy