KeyError: 'advantages' in PPO MARL

I am using Ray 2.50.1 to implement a MARL model with PPO.
However, I keep running into the following problem:

'advantages'
KeyError: 'advantages'

During handling of the above exception, another exception occurred:

  File "/home/tangjintong/multi_center_1020/main.py", line 267, in <module>
    result = algo.train()
             ^^^^^^^^^^^^
KeyError: 'advantages'

My model is:

import numpy as np
import matplotlib.pyplot as plt
from torch import nn
import os
from gymnasium import spaces
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.typing import TensorType
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core import Columns
from ray.rllib.utils.annotations import override
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI


class MaskedRLModule(TorchRLModule):
    def setup(self):
        super().setup()
        input_dim = self.observation_space['obs'].n
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = self.action_space.n
        self.policy_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
        self.value_net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

    def _forward(self, batch: TensorType, **kwargs) -> TensorType:
        # batch["obs"] shape: [B, obs_size]
        logits = self.policy_net(batch["obs"]["obs"].float())
        # Handle action masking
        if "action_mask" in batch["obs"]:
            mask = batch["obs"]["action_mask"]
            # Set logits of invalid actions to -inf
            logits = logits.masked_fill(mask == 0, -1e9)
        return {Columns.ACTION_DIST_INPUTS: logits}

    @override(ValueFunctionAPI)
    def compute_values(self, batch, **kwargs):
        return self.value_net(batch["obs"]["obs"].float())


class Grid9x9MultiAgentEnv(MultiAgentEnv):
    """9x9 discrete grid multi-agent environment (2 homogeneous agents)."""

    def __init__(self, env_config=None):
        super().__init__()
        env_config = env_config or {}
        self._num_agents = env_config.get("num_agents")         # Private attribute so it does not clash with MultiAgentEnv's num_agents property
        self.agents = self.possible_agents = [f"agent_{i}" for i in range(self._num_agents)]
        self.render_step_num = env_config.get("render_step_num")
        self.truncation_step_num = env_config.get("truncation_step_num")
        self.size = env_config.get("size")

        self.grid = np.zeros((self.size, self.size), dtype=np.int8)  # 0=empty, 1=occupied
        self.agent_positions = {agent: None for agent in self.agents}
        self._update_masks()

        self.step_in_episode = 0
        self.current_total_step = 0
        # Both action and observation spaces are discrete grids of size 9*9
        self.action_space = spaces.Dict({
            f"agent_{i}": spaces.Discrete(self.size * self.size)
            for i in range(self._num_agents)
        })

        self.observation_space = spaces.Dict({
            f"agent_{i}": spaces.Dict({
                "obs": spaces.Discrete(self.size * self.size),
                "action_mask": spaces.Discrete(self.size * self.size),
            })
            for i in range(self._num_agents)
        })

        coords = np.array([(i, j) for i in range(self.size) for j in range(self.size)])  # 81×2, each row is (row, col)
        # Calculate Euclidean distance matrix
        diff = coords[:, None, :] - coords[None, :, :]  # 81×81×2
        self.distance_matrix = np.sqrt((diff ** 2).sum(-1))  # 81×81

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        print(f"Environment reset at step {self.current_total_step}.")
        self.grid = np.zeros((self.size, self.size), dtype=np.int8)  # 0=empty, 1=occupied
        self.agent_positions = {agent: None for agent in self.agents}
        self._update_masks()
        self.step_in_episode = 0

        obs = {agent: self._get_obs(agent) for agent in self.agents}
        return obs, {}

    def _update_masks(self):
        """Update action masks: cannot select occupied cells."""
        mask = 1 - self.grid.flatten()  # 1 indicates available positions, 0 indicates unavailable positions
        self.current_masks = {agent: mask.copy() for agent in self.agents}

        # If both agents have chosen positions, mutually prohibit selecting the same position
        for agent, pos in self.agent_positions.items():
            if pos is not None:
                for other in self.agents:
                    if other != agent:
                        self.current_masks[other][pos] = 0

    def _get_obs(self, agent):
        return { 
            "obs": self.grid.flatten().astype(np.float32),
            "action_mask": self.current_masks[agent].astype(np.float32),
        }

    def step(self, actions):
        """actions is a dict: {agent_0: act0, agent_1: act1}"""
        rewards = {agent: 0.0 for agent in self.agents}
        terminations = {agent: False for agent in self.agents}
        truncations = {agent: False for agent in self.agents}
        infos = {agent: {} for agent in self.agents}

        # Check for action conflicts and update grid and agent_positions
        chosen_positions = set()
        for agent, act in actions.items():
            if self.current_masks[agent][act] == 0:
                rewards[agent] = -1.0
            else:
                if act in chosen_positions:
                    # Conflicting position, keep agent_position[agent] unchanged
                    rewards[agent] = -1.0
                else:
                    if self.agent_positions[agent] is not None:
                        row, col = divmod(self.agent_positions[agent], self.size)
                        self.grid[row, col] = 0  # Release previous position
                    row, col = divmod(act, self.size)
                    self.grid[row, col] = 1  # Occupy new position
                    self.agent_positions[agent] = act
                    chosen_positions.add(act)

        rewards = self.reward()

        self._update_masks()
        obs = {agent: self._get_obs(agent) for agent in self.agents}

        self.step_in_episode += 1
        self.current_total_step += 1
        
        # When any agent terminates, e.g., the entire episode terminates:
        if self.step_in_episode >= self.truncation_step_num:
            for agent in self.agents:
                terminations[agent] = True
                truncations[agent] = True
                self.visualize()

        # "__all__" must exist and be accurate
        terminations["__all__"] = all(terminations[a] for a in self.agents)
        truncations["__all__"] = all(truncations[a] for a in self.agents)

        return obs, rewards, terminations, truncations, infos

    def reward(self):
        """
        Reward function: The reward for a merchant's chosen cell is the total number of customers served * product price.
        Customer cost is transportation cost (related to distance) + product price, so customers only choose the merchant that minimizes their cost.
        Since merchants have the same product price, customers choose the nearest merchant.
        Therefore, each merchant wants their chosen cell to cover more customers.
        Simplified here: reward equals the number of customers covered by that merchant.
        """
        positions = list(self.agent_positions.values())
        # Get covered customers (i.e., customers closer to this merchant)
        customer_agent = np.argmin(self.distance_matrix[positions], axis=0)
        # Count the number of customers corresponding to each agent as reward
        values, counts = np.unique(customer_agent, return_counts=True)
        return {f"agent_{v}": counts[i] for i, v in enumerate(values)}

    def visualize(self):
        n = self.size
        fig, ax = plt.subplots(figsize=(6, 6))

        # Draw grid lines
        for x in range(n + 1):
            ax.axhline(x, color='k', lw=1)
            ax.axvline(x, color='k', lw=1)

        # Draw occupied positions
        for pos in self.agent_positions.values():
            row, col = divmod(pos, n)
            ax.add_patch(plt.Rectangle((col, n - 1 - row), 1, 1, color='lightgray'))

        # Draw agents
        colors = ["red", "blue"]
        for i, (agent, pos) in enumerate(self.agent_positions.items()):
            row, col = divmod(pos, n)
            ax.scatter(col + 0.5, n - 1 - row + 0.5, c=colors[i], s=200, label=agent)

        ax.set_xlim(0, n)
        ax.set_ylim(0, n)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_aspect('equal')
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper right')
        if not os.path.exists("figures"):
            os.makedirs("figures")
        plt.savefig(f"figures/grid_step_{self.current_total_step}.png")
        plt.close()


if __name__ == "__main__":
    ray.init(ignore_reinit_error=True)

    env_name = "Grid9x9MultiAgentEnv"
    tune.register_env(env_name, lambda cfg: Grid9x9MultiAgentEnv(cfg))

    def policy_mapping_fn(agent_id, episode, **kwargs):
        # Homogeneous agents share one policy
        return "shared_policy"

    env_config = {
        # Environment parameters can be passed here
        "render_step_num": 500,
        "truncation_step_num": 500,
        "num_agents": 2,
        "size": 9,
    }

    model_config = {
        "hidden_dim": 128,
    }

    config = (
        PPOConfig()
        .environment(
            env=env_name,
            env_config=env_config
        )
        .multi_agent(
            policies={"shared_policy"},
            policy_mapping_fn=policy_mapping_fn,
        )
        .rl_module(
            rl_module_spec=RLModuleSpec(
                module_class=MaskedRLModule,
                model_config=model_config,
            )
        )
        .framework("torch")
        .env_runners(
            num_env_runners=1,                  # Number of parallel EnvRunner workers
            rollout_fragment_length=50,         # Sampling fragment length per runner
            batch_mode="truncate_episodes",     # Truncate episodes at the fragment length instead of waiting for complete episodes
            add_default_connectors_to_env_to_module_pipeline=True,
            add_default_connectors_to_module_to_env_pipeline=True
        )
        .resources(num_gpus=1)
        .training(
            train_batch_size=1000,  # Minimum number of experience steps to collect before each update
            minibatch_size=128,     # Number of steps per minibatch during update
            lr=1e-4,  # Learning rate
            use_gae=True,
            use_critic=True,
        )
    )
    algo = config.build_algo()
    print("Start training...")
    for i in range(5):
        result = algo.train()
        print(f"Iteration {i}: reward={result['episode_reward_mean']}")

I have read some posts about this problem, but none of them helped. Any help would be appreciated!

The KeyError: 'advantages' in Ray RLlib PPO with a custom RLModule usually means that the 'advantages' field is missing from the training batch, which PPO needs to compute its loss. This is almost always because the value function predictions (Columns.VF_PREDS) are not returned by your RLModule's forward methods, so RLlib cannot compute advantages during postprocessing. In your MaskedRLModule, _forward only returns Columns.ACTION_DIST_INPUTS but not Columns.VF_PREDS, which GAE and the PPO loss require (see the Ray RLlib docs on RLModules and PPO).

To fix this, update your _forward (and _forward_train, if you override it) to also return Columns.VF_PREDS, e.g.:

def _forward(self, batch: TensorType, **kwargs) -> TensorType:
    logits = self.policy_net(batch["obs"]["obs"].float())
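    # RLlib expects the value predictions to be a 1-D tensor of shape [B], hence the squeeze below.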
    values = self.value_net(batch["obs"]["obs"].float()).squeeze(-1)
    if "action_mask" in batch["obs"]:
        mask = batch["obs"]["action_mask"]
        logits = logits.masked_fill(mask == 0, -1e9)
    return {
        Columns.ACTION_DIST_INPUTS: logits,
        Columns.VF_PREDS: values,
    }

This will allow RLlib to compute advantages during postprocessing and avoid the KeyError.
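
One more thing worth double-checking (this is an assumption based on your posted code, not something the traceback proves): ValueFunctionAPI.compute_values is generally expected to return value estimates of shape [B], while your value_net outputs [B, 1]. Squeezing the trailing dimension keeps it consistent with the VF_PREDS you now return from _forward, for example:

@override(ValueFunctionAPI)
def compute_values(self, batch, **kwargs):
    # Squeeze the trailing dim so the value estimates are 1-D ([B]), matching VF_PREDS.
    return self.value_net(batch["obs"]["obs"].float()).squeeze(-1)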

Would you like a step-by-step explanation of why this is required and how RLlib computes advantages?
