Trouble Migrating Multi-Agent PPO with Custom Model (Action Masking + CNN + MLP) to New RLlib API

1. Severity of the issue: (select one)

High: Completely blocks me.

2. Environment:

  • Ray version: 2.47.1
  • Python version: 3.11.7
  • OS: Windows 10 :frowning:

3. What happened vs. what you expected:

Hello everyone,

I’m trying to implement a multi-agent learning setup with action masking. I previously had this working using the old RLlib API stack, but now I’m upgrading to the new API and running into some trouble understanding how everything fits together.

Based on my (possibly wrong) understanding, I created a custom model that handles action masking and processes the observation in two separate branches:

  • A CNN to process the image part of the observation
  • A standard MLP to process the remaining vector-based observation
import torch
import torch.nn as nn
from ray.rllib.core.columns import Columns

from ray.rllib.core.rl_module.torch import TorchRLModule

from ray.rllib.utils.typing import TensorType
from typing import Any, Dict, Optional

class MaskedModel(nn.Module):
    def __init__(self, num_outputs):
        super().__init__()

        # CNN encoder for truncated_image (assumes [3, 32, 32])
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # -> (32, 16, 16)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), # -> (64, 8, 8)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # -> (128, 4, 4)
            nn.ReLU(),
            nn.Flatten()  # -> 128 * 4 * 4 = 2048
        )

        self.image_out_size = 2048

        # Fully connected encoder for vector input
        self.vector_input_size = (
            100 + 1 + 1  # softmax + truncated + prev_agent_action
        )

        self.vector_encoder = nn.Sequential(
            nn.Linear(self.vector_input_size, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
        )

        self.fusion_size = self.image_out_size + 64
        # Maps the concatenated vector to action logits.
        self.policy_head = nn.Linear(self.fusion_size, num_outputs)
        # Maps it to a scalar value (for PPO critic).
        self._value_branch = nn.Linear(self.fusion_size, 1)

    def forward(self, fused_features):
        """Forward pass that returns both logits and value."""
        logits = self.policy_head(fused_features)
        value = self._value_branch(fused_features).squeeze(-1)
        
        return logits, value

class MaskedPPOModule(TorchRLModule):
    def setup(self):
        self.model = MaskedModel(
            num_outputs=self.action_space.n,  # Use the action space to infer the number of output nodes.
        )

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model.to(self.device)

    def _forward(self, batch, **kwargs):
        obs = batch["obs"]
        action_mask = obs["action_mask"].float().to(self.device)
        fused = self.concat(obs)        

        logits, value = self.model(fused)

        # Apply action mask
        inf_mask = torch.clamp(torch.log(action_mask), min=-1e10)
        masked_logits = logits + inf_mask
        
        # see https://discuss.ray.io/t/keyerror-advantages-when-training-ppo-with-custom-model-in-rllib/21876/4
        return {
            Columns.ACTION_DIST_INPUTS: masked_logits,
            Columns.VF_PREDS: value,
            # These will be filled in by the PPO algorithm during training
            Columns.ADVANTAGES: torch.zeros_like(value),
            Columns.VALUE_TARGETS: torch.zeros_like(value),
        }
        
    def _forward_train(self, batch, **kwargs):
        obs = batch["obs"]
        action_mask = obs["action_mask"].float().to(self.device)
        fused = self.concat(obs)        

        logits, value = self.model(fused)

        # Apply action mask
        inf_mask = torch.clamp(torch.log(action_mask), min=-1e10)
        masked_logits = logits + inf_mask
        
        outs = super()._forward_train(batch, **kwargs)
        outs[Columns.ACTION_DIST_INPUTS] = masked_logits
        
        return outs
        
    # Get the joint feature vector.
    def concat(self, obs):
        # Ensure all tensors are on the correct self.device
        image = obs["truncated_image"].float().to(self.device) / 255.0
        softmax = obs["softmax"].float().to(self.device)
        truncated = obs["truncated"].float().unsqueeze(-1).to(self.device)
        prev_action = obs["prev_agent_action"].unsqueeze(-1).float().to(self.device)

        # Forward pass through CNN and vector encoder
        img_out = self.model.cnn(image)
        vector_input = torch.cat([softmax, truncated, prev_action], dim=-1)
        vector_out = self.model.vector_encoder(vector_input)
        
        # Concatenate visual and vector features
        return torch.cat([img_out, vector_out], dim=-1)
    
    def compute_values(
        self, batch: Dict[str, Any], embeddings: Optional[Any] = None
    ) -> TensorType:
        """Compute value estimates for the critic head."""
        obs = batch[Columns.OBS]
        fused = self.concat(obs)
        
        return self.model._value_branch(fused).squeeze(-1)

It took quite some effort to figure out what the _forward() method is supposed to return, so I’m not entirely sure whether I’m missing something or overwriting an important part of the pipeline.

My main problem now is that PPO doesn't seem to learn anything: the output looks completely random. I've tried tweaking many things, but I'm at a point where I don't know what's causing the issue.
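One way to make "looks completely random" measurable is to log the entropy of the masked action distribution relative to a uniform choice over the valid actions. Below is a minimal plain-PyTorch sketch. It reuses the same clamp-log masking trick as `_forward` above. The helper name `masked_entropy_ratio` is hypothetical and not part of RLlib.

```python
import torch
from torch.distributions import Categorical


def masked_entropy_ratio(logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    """Entropy of the masked policy divided by the entropy of a uniform
    distribution over the valid actions (1.0 = as random as the mask allows)."""
    # Same masking as in the module: log(1) = 0, log(0) = -inf clamped to -1e10.
    masked_logits = logits + torch.clamp(torch.log(action_mask), min=-1e10)
    entropy = Categorical(logits=masked_logits).entropy()
    num_valid = action_mask.sum(dim=-1)
    max_entropy = torch.log(num_valid).clamp(min=1e-8)
    # With a single valid action there is no choice at all; report 0 in that case.
    return torch.where(num_valid > 1, entropy / max_entropy, torch.zeros_like(entropy))
```

If this ratio stays close to 1.0 for the `tile_selector` policy throughout training, the policy really is not moving away from a uniform choice over valid actions. If it drops, the policy is committing to something, and the problem is more likely what it commits to.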

For reference, I have the same environment implemented in SB3 (Stable-Baselines3), and it works perfectly there. I wanted to extend it with RLlib and add a second agent. However, even when the second agent has only one available action, the setup doesn't work.

That’s why I wanted to ask: does this custom model approach look reasonable, or am I fundamentally misunderstanding something?

I’d really appreciate any help or guidance you can offer. Thanks in advance!

BR
Roy

I learned model implementation under the new API not too long ago myself. It's hard to say what might be going wrong with your script without seeing how you instantiate everything via a config (though I may have missed something in your code). There's a solid set of examples in the main repo here, and a direct example of action-masking here. I'd try putting your module into the format used there, checking whether that solves things, and going from there.

Beyond that, it's useful to check whether the weights are changing (just not as desired) or whether your module isn't properly hooked up at all. A few months back, I was trying to diagnose a lack of learning, and it turned out my encoder wasn't receiving gradients in the first place.
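A minimal sketch of such a gradient-flow check in plain PyTorch follows. `params_without_gradient` and the `TwoBranch` toy model are hypothetical names made up for illustration. With RLlib you would feed the helper a loss computed from your own module's outputs on a sample batch. That part is not shown.

```python
import torch
import torch.nn as nn


def params_without_gradient(module: nn.Module, loss: torch.Tensor) -> list[str]:
    """Backpropagate `loss` and return the names of trainable parameters
    whose gradient is missing (None) or exactly zero everywhere."""
    module.zero_grad(set_to_none=True)
    loss.backward()
    dead = []
    for name, param in module.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None or torch.count_nonzero(param.grad) == 0:
            dead.append(name)
    return dead


class TwoBranch(nn.Module):
    """Toy encoder + head; `detach_encoder` simulates an encoder cut off from the graph."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)
        self.detach_encoder = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = torch.tanh(self.encoder(x))
        if self.detach_encoder:
            h = h.detach()  # gradients can no longer reach self.encoder
        return self.head(h)
```

With the encoder detached, the helper reports `encoder.weight` and `encoder.bias`. That is the kind of silent disconnect described above.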

Could you post a few more details, like your action_mask tensor or your PPO config? That might help us debug. :slight_smile:

As for your question, the _forward() method should return a dictionary with keys like Columns.ACTION_DIST_INPUTS and Columns.VF_PREDS. You just need to make sure your PPO setup is using them correctly.

Hi @MCW_Lad,

thanks a lot for your response.
I’m actually relieved that the issue isn’t something obviously wrong.

I did try following the action masking example, but I ran into a new challenge: since my observation is a dictionary, I believed I needed a separate encoder to handle it properly. I attempted that, but it left me even more unsure whether I was doing things correctly, since I'm still quite new to reinforcement learning.
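A separate encoder for a dict observation could look roughly like the sketch below. It is plain PyTorch and mirrors the shapes from the original post: a `[3, 32, 32]` channels-first image plus the `softmax`, `truncated` and `prev_agent_action` entries. `DictObsEncoder` is a hypothetical name and not an RLlib encoder class. Wiring it into RLlib's encoder/catalog machinery is not shown. The `action_mask` key is deliberately not encoded, since the mask is applied to the logits instead.

```python
import torch
import torch.nn as nn


class DictObsEncoder(nn.Module):
    """Fuses a dict observation into one feature vector:
    a CNN branch for the image and an MLP branch for the vector entries."""

    def __init__(self, vector_keys: tuple[str, ...], vector_size: int, hidden: int = 64):
        super().__init__()
        self.vector_keys = vector_keys
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1), nn.ReLU(),    # -> (32, 16, 16)
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1), nn.ReLU(),   # -> (64, 8, 8)
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), nn.ReLU(),  # -> (128, 4, 4)
            nn.Flatten(),                                                        # -> 2048
        )
        self.mlp = nn.Sequential(
            nn.Linear(vector_size, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.output_size = 128 * 4 * 4 + hidden

    def forward(self, obs: dict[str, torch.Tensor]) -> torch.Tensor:
        image = obs["truncated_image"].float() / 255.0
        # Per-sample scalars arrive as shape [B]; give them a feature dim of 1.
        parts = [
            obs[k].float() if obs[k].dim() == 2 else obs[k].float().unsqueeze(-1)
            for k in self.vector_keys
        ]
        vector = torch.cat(parts, dim=-1)
        return torch.cat([self.cnn(image), self.mlp(vector)], dim=-1)
```

The policy and value heads would then be plain `nn.Linear(encoder.output_size, ...)` layers on top of this fused vector, as in the `MaskedModel` above.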

Your suggestion about checking gradient flow was also helpful and something I hadn’t looked into yet. Unfortunately, that doesn’t seem to be the root cause either.

Hello @christina,

thanks again for all the help!

Here’s what my action_mask looks like for the first agent over one episode:

(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0')
(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1.]], device='cuda:0')
(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1.]], device='cuda:0')
(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0.]], device='cuda:0')
(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1., 0.]], device='cuda:0')
(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0., 1., 1., 1., 0.]], device='cuda:0')
(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 0.]], device='cuda:0')
(MultiAgentEnvRunner pid=18580) Action Mask: tensor([[1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0., 0., 0., 0., 1., 0.]], device='cuda:0')

The second agent always returns Action Mask: tensor([[1.]], device='cuda:0'), i.e. a single available action, as a kind of sanity check. I was hoping this minimal setup could still produce meaningful learning before I scale up to more actions.

Sampling also works as expected: the model applies the action mask correctly and never picks a masked action.
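A quick offline way to double-check that claim is sketched below in plain PyTorch, using the same clamp-log masking as `_forward` and the last mask row from the log above. The helper names are hypothetical. The masked actions are given deliberately large raw logits, so a broken mask would show up immediately.

```python
import torch
from torch.distributions import Categorical


def apply_action_mask(logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    # log(1) = 0 leaves valid logits unchanged; log(0) = -inf is clamped to a
    # large finite negative number so no NaNs/infs appear downstream.
    return logits + torch.clamp(torch.log(action_mask), min=-1e10)


def invalid_probability_mass(logits: torch.Tensor, action_mask: torch.Tensor) -> torch.Tensor:
    """Total probability the masked distribution assigns to invalid actions."""
    probs = Categorical(logits=apply_action_mask(logits, action_mask)).probs
    return (probs * (1.0 - action_mask)).sum(dim=-1)


def sampled_invalid_count(logits: torch.Tensor, action_mask: torch.Tensor, num_samples: int = 10_000) -> int:
    """How many of `num_samples` sampled actions hit a masked-out action."""
    dist = Categorical(logits=apply_action_mask(logits, action_mask))
    samples = dist.sample((num_samples,))                     # [num_samples, B]
    mask = action_mask.unsqueeze(0).expand(num_samples, -1, -1)  # [num_samples, B, A]
    chosen_valid = mask.gather(-1, samples.unsqueeze(-1)).squeeze(-1)
    return int((chosen_valid == 0).sum())
```

Both checks should return zero. If they do, the masking itself is fine and the cause of the non-learning is elsewhere (for example in the training columns the module returns).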

My PPO config looks like the following:

from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

config = (
  PPOConfig()
  .environment("custom_classification_env", env_config={"resolution_scale": resolution_scale, "num_agents": 2})
  .framework("torch")
  .env_runners(
    num_env_runners=1,
    num_gpus_per_env_runner=0.5,
    num_cpus_per_env_runner=4,
    env_to_module_connector=passthrough_connector_factory,  # Pass obs as dict.
    batch_mode="complete_episodes",
  )
  .resources(
    num_gpus=0.5,
  )
  .learners(
    num_gpus_per_learner=0.5,
  )
  .multi_agent(
    policies={"tile_selector", "res_selector"},
    policy_mapping_fn=lambda aid, eps, **kw: (
      "res_selector" if aid.startswith("res_selector") else "tile_selector"
    ),
    policies_to_train=["tile_selector"],  # TODO ["tile_selector", "res_selector"]
  )
  .rl_module(
    rl_module_spec=RLModuleSpec(
      module_class=MaskedPPOModule,  # Apply action masking.
    )
  )
)

tune.run(
  "PPO",
  config=config.to_dict(),
  stop={"num_env_steps_sampled_lifetime": 1e6},
  checkpoint_freq=5e4 // (2048),
  storage_path="./log/", # log directory for TensorBoard
  name="tile_selector_test",
  checkpoint_at_end=True,
)

On top of that, I'm using a connector to ensure observations stay as dicts:

import gymnasium as gym

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override

from ray.rllib.utils.typing import EpisodeType

from typing import Any, Dict, List, Optional


class PassThroughConnector(ConnectorV2):
    """A connector that does not modify the data it receives.
    see: https://github.com/ray-project/ray/blob/master/rllib/examples/connectors/prev_actions_prev_rewards.py
    and https://github.com/ray-project/ray/blob/master/rllib/connectors/env_to_module/prev_actions_prev_rewards.py
    """
    def recompute_output_observation_space(
        self,
        input_observation_space: gym.Space,
        input_action_space: gym.Space,
    ) -> gym.Space:
        return input_observation_space

    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        **kwargs,
    ):
        super().__init__(
            input_observation_space=input_observation_space,
            input_action_space=input_action_space,
            **kwargs,
        )

    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Optional[Dict[str, Any]],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:

        return batch

Regarding your suggestion about _forward(): when I leave out Columns.ADVANTAGES, PPO throws KeyError: 'advantages'. So it seems the new API stack still requires that key, at least in my current configuration.

Thanks again and I appreciate any further insights or ideas!


I didn’t spot anything that should cause issues (though I may just have missed it).

If I were facing this issue, my first thought would be to try running the action-masking module as-is on a toy environment, to see if the source of the issue really does have something to do with action-masking, and not with something else (hyperparameters, the multi-agent setup, or some other unforeseen complication).
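To make the toy-environment idea concrete, here is a hypothetical single-step "masked bandit" with a Gymnasium-style `reset()`/`step()` signature. It is written in plain Python so it has no dependencies. To run it under RLlib you would still need to wrap it as a `gymnasium.Env` (or `MultiAgentEnv`) with proper observation and action spaces, which is not shown. One action (`target`) is always valid and always rewarded. A working masked PPO setup should learn to pick it almost every time.

```python
import random


class MaskedBanditEnv:
    """Hypothetical toy environment with a Gymnasium-style reset()/step() API.

    Every episode is a single step. The observation is a dict holding a random
    action mask; the `target` action is always valid and pays 1.0, every other
    valid action pays 0.0, and choosing a masked action raises ValueError.
    """

    def __init__(self, num_actions=16, target=0, p_valid=0.7, seed=None):
        self.num_actions = num_actions
        self.target = target
        self.p_valid = p_valid
        self.rng = random.Random(seed)
        self.mask = []

    def _new_observation(self):
        # Sample a fresh mask, but never hide the rewarded action.
        self.mask = [1.0 if self.rng.random() < self.p_valid else 0.0
                     for _ in range(self.num_actions)]
        self.mask[self.target] = 1.0
        return {"action_mask": list(self.mask)}

    def reset(self, *, seed=None, options=None):
        if seed is not None:
            self.rng.seed(seed)
        return self._new_observation(), {}

    def step(self, action):
        if self.mask[action] == 0.0:
            raise ValueError(f"masked action {action} was selected")
        reward = 1.0 if action == self.target else 0.0
        # terminated=True, truncated=False: each episode ends after one step.
        return self._new_observation(), reward, True, False, {}
```

A uniformly random policy over the valid actions earns roughly 1 / (number of valid actions) per episode here. If the mean episode return does not climb well above that, the problem is in the training pipeline rather than in the environment.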

If the problem reproduces there, you can probably make progress by gradually interpolating between your action-masking setup and the one in the examples, since we know that one should work. If it doesn't, then paring down your environment (if possible) and gradually escalating the complexity until something breaks should help identify what's going wrong.

Hi @royschn, at a glance I don't see anything terribly wrong with your action masks. :thinking: I do like @MCW_Lad's suggestion of paring down the environment little by little: start from a basic, working environment and add layers of complexity to see where it breaks. How else are you debugging right now?


Hi @royschn,

I think returning advantages from `_forward` is the key to your issues. You cannot compute advantages from within forward. You need a trajectory to be able to compute the return.
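To illustrate why a trajectory is needed, here is a minimal Generalized Advantage Estimation (GAE) sketch in plain Python. Each advantage depends on all later rewards and value estimates in the episode, which a `_forward` call over a single timestep cannot see. As far as I know, PPO on the new API stack computes these columns on the learner side via a connector (I believe it is called `GeneralAdvantageEstimation`) that uses the module's `compute_values()`. Treat that as something to verify against the RLlib source for your version. The function below is an illustration of the formula, not RLlib's code.

```python
def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """GAE over one trajectory of length T.

    rewards[t], values[t] = V(s_t), dones[t] = 1 if the episode terminated at t.
    last_value = V(s_T), used to bootstrap if the trajectory was cut off.
    Returns (advantages, value_targets), both lists of length T.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    # Walk backwards: each step needs the *next* value and the accumulated future GAE.
    for t in reversed(range(len(rewards))):
        next_value = last_value if t == len(rewards) - 1 else values[t + 1]
        not_done = 1.0 - float(dones[t])
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
    # Value targets for the critic are advantages plus the baseline.
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets
```

With `gamma=1.0` and `lam=1.0`, rewards `[1, 1, 1]` and zero values give advantages `[3.0, 2.0, 1.0]`, i.e. plain returns-to-go. None of these numbers is available until the episode (or rollout fragment) is over.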

I see from your comment in `_forward` that you think they are being computed or replaced later. Have you checked them? I don't think they are, or you would not be seeing the missing `advantages` KeyError to begin with.

If they are not being calculated and updated in post-processing, then advantages of zero will mask all the gradients for the policy.
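A small PyTorch check of that claim follows. It uses a hand-written clipped surrogate, a sketch of the standard PPO policy loss and not RLlib's actual loss code. The helper names are hypothetical. With all-zero advantages, the surrogate's gradient with respect to the policy logits is exactly zero. Only whatever other loss terms are configured (value loss, entropy bonus, KL penalty) would still move the weights. That would fit a policy that stays close to its random initialization.

```python
import torch


def ppo_clip_surrogate_loss(logp_new, logp_old, advantages, clip_param=0.2):
    """Clipped PPO policy loss, negated so it can be minimized."""
    ratio = torch.exp(logp_new - logp_old)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * advantages
    return -torch.min(unclipped, clipped).mean()


def policy_grad_norm(advantages, seed=0):
    """Norm of the surrogate's gradient w.r.t. the logits of a random 8 x 4 batch."""
    torch.manual_seed(seed)
    logits = torch.randn(8, 4, requires_grad=True)
    actions = torch.randint(0, 4, (8,))
    logp = torch.log_softmax(logits, dim=-1).gather(-1, actions.unsqueeze(-1)).squeeze(-1)
    # First PPO epoch: old and new log-probs coincide, so the ratio is exactly 1.
    loss = ppo_clip_surrogate_loss(logp, logp.detach(), advantages)
    loss.backward()
    return logits.grad.norm().item()
```

`policy_grad_norm(torch.zeros(8))` is `0.0`, while any non-zero advantages (e.g. `torch.ones(8)`) give a non-zero gradient.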