KeyError: 'advantages' on MARL

Hello, I'm trying to train a multi-agent env with the PPO algorithm. During training I get a KeyError: 'advantages'. I've searched for many days but didn't manage to find a solution. Can someone help me?
Thanks.

Can you share your PPOConfig setup and any scripts, if available, for reproducibility? Also, which version of RLlib are you using?

Hello, thanks for your response. Here is my RLModule class:

import torch

from gymnasium.spaces.utils import flatten_space

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule


class VPGTorchRLModule(TorchRLModule):
    """A simple VPG (vanilla policy gradient)-style RLModule for testing purposes.

    Use this as a minimum, bare-bones example implementation of a custom TorchRLModule.
    """

    def setup(self):
        super().setup()
        input_dim = flatten_space(self.observation_space).shape[0]
        hidden_dim = self.model_config["hidden_dim"]
        output_dim = 9  # Discrete(9) action space -> 9 logits.

        self._policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

    def _forward(self, batch, **kwargs):
        # Push the observations from the batch through our `self._policy_net`.
        #print("Hello dear !!!!!!", batch)
        obs = batch[Columns.OBS]

        action_logits = self._policy_net(obs)
        # Return parameters for the (default) action distribution, which is
        # `TorchCategorical` (due to our action space being `gym.spaces.Discrete`).
        print("result sent:  ", {Columns.ACTION_DIST_INPUTS: action_logits})
        return {Columns.ACTION_DIST_INPUTS: action_logits}

And my config file:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

env = TicTacToe()

config = (
    PPOConfig()
    .environment(TicTacToe)
    .multi_agent(
        policies={"p1", "p2"},
        policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id[-1]}",

    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={
                "p1": RLModuleSpec(
                    module_class=VPGTorchRLModule,
                    model_config={"hidden_dim": 32},
                    observation_space=env.observation_spaces["player1"]
                ),
                "p2": RLModuleSpec(
                    module_class=VPGTorchRLModule,
                    model_config={"hidden_dim": 16},
                    observation_space=env.observation_spaces["player2"]
                ),
            }
        )
    )
    .training(use_critic=True, use_gae=True)  # Enable critic and GAE
)

And the way I train the model:

ppo = config.build()

for i in range(5):
    train_results = ppo.train()

The environment is a simple tic-tac-toe implementation I use for testing:

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Dict

from ray.rllib.env.multi_agent_env import MultiAgentEnv


class TicTacToe(MultiAgentEnv):
    """A two-player game in which any player tries to complete one row in a 3x3 field.

    The observation space is Box(-1.0, 1.0, (9,)), where each index represents a distinct
    field on a 3x3 board. A value of 0.0 means the field is empty, -1.0 means
    the opponent owns the field, and 1.0 means we occupy the field:
    ----------
    | 0| 1| 2|
    ----------
    | 3| 4| 5|
    ----------
    | 6| 7| 8|
    ----------

    The action space is Discrete(9). Actions landing on an already occupied field
    are penalized with a -5.0 reward for the player taking them.

    Once a player completes a row, column, or diagonal, they receive +5.0 reward and the
    losing player receives -5.0 reward. In all other cases, both players receive 0.0 reward.
    """
    def __init__(self, config=None):
        super().__init__()

        # Define the agents in the game.
        self.agents = self.possible_agents = ["player1", "player2"]

        # Each agent observes a 9D tensor, representing the 3x3 fields of the board.
        # A 0 means an empty field, a 1 represents a piece of player 1, a -1 a piece of
        # player 2.
        self.observation_spaces = Dict({
            "player1": gym.spaces.Box(-1.0, 1.0, (9,), np.float32),
            "player2": gym.spaces.Box(-1.0, 1.0, (9,), np.float32),
        })
        # Each player has 9 actions, encoding the 9 fields each player can place a piece
        # on during their turn.
        self.action_spaces = {
            "player1": gym.spaces.Discrete(9),
            "player2": gym.spaces.Discrete(9),
        }

        self.board = None
        self.current_player = None

    def reset(self, *, seed=None, options=None):
        self.board = [0] * 9
        # Pick a random player to start the game.
        self.current_player = np.random.choice(["player1", "player2"])
        # Return observations dict (only with the starting player, which is the one
        # we expect to act next).
        return {
            self.current_player: np.array(self.board, np.float32),
        }, {}

    def step(self, action_dict):
        action = action_dict[self.current_player]

        # Create a rewards-dict (containing the rewards of the agent that just acted).
        rewards = {self.current_player: 0.0}
        # Create a terminateds-dict with the special `__all__` agent ID, indicating that
        # if True, the episode ends for all agents.
        terminateds = {"__all__": False}

        opponent = "player1" if self.current_player == "player2" else "player2"

        # Penalize trying to place a piece on an already occupied field.
        if self.board[action] != 0:
            rewards[self.current_player] -= 5.0
        # Change the board according to the (valid) action taken.
        else:
            self.board[action] = 1 if self.current_player == "player1" else -1

            # After having placed a new piece, figure out whether the current player
            # won or not.
            if self.current_player == "player1":
                win_val = [1, 1, 1]
            else:
                win_val = [-1, -1, -1]
            if (
                # Horizontal win.
                self.board[:3] == win_val
                or self.board[3:6] == win_val
                or self.board[6:] == win_val
                # Vertical win.
                or self.board[0:7:3] == win_val
                or self.board[1:8:3] == win_val
                or self.board[2:9:3] == win_val
                # Diagonal win.
                or self.board[0:9:4] == win_val
                or self.board[2:7:2] == win_val
            ):
                # Final reward is +5 for victory and -5 for a loss.
                rewards[self.current_player] += 5.0
                rewards[opponent] = -5.0
                print("the winner is: ", self.current_player)
                # Episode is done and needs to be reset for a new game.
                terminateds["__all__"] = True

            # The board might also be full w/o any player having won/lost.
            # In this case, we simply end the episode and neither player receives
            # +5 or -5 reward.
            elif 0 not in self.board:
                terminateds["__all__"] = True

        # Flip players and return an observations dict with only the next player to
        # make a move in it.
        self.current_player = opponent

        return (
            {self.current_player: np.array(self.board, np.float32)},
            rewards,
            terminateds,
            {},
            {},
        )
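
As a quick sanity check (not part of the training script), I roll out one random game to verify that reset() and step() only ever return the observation of the agent expected to act next:

# Smoke test: play one random game and print who acts, their action, and the rewards.
env = TicTacToe()
obs, _ = env.reset()
terminateds = {"__all__": False}
while not terminateds["__all__"]:
    agent = next(iter(obs))                     # The single agent expected to act.
    action = env.action_spaces[agent].sample()  # Random (possibly invalid) move.
    obs, rewards, terminateds, truncateds, infos = env.step({agent: action})
    print(agent, action, rewards)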

Hope it's clear.

I suspect that I have to implement a value function to be able to use PPO. But when I implement it in the RL module, the algorithm doesn't use it.
I'm using version 2.44.1.
Thanks.

Here is the solution I found thanks to other comments on the forum.

import torch

from gymnasium.spaces.utils import flatten_space

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch import TorchRLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.torch_utils import FLOAT_MIN, flatten_inputs_to_1d_tensor


class VPGTorchRLModule(TorchRLModule, ValueFunctionAPI):
    """A simple VPG (vanilla policy gradient)-style RLModule for testing purposes.

    Use this as a minimum, bare-bones example implementation of a custom TorchRLModule.
    """
    @override(TorchRLModule)
    def setup(self):
        super().setup()
        input_dim = flatten_space(self.observation_space["observation"]).shape[0]#["observation"]
        hidden_dim_actor = self.model_config["hidden_dim_actor"]
        hidden_dim_critic = self.model_config["hidden_dim_critic"]
        output_dim = flatten_space(self.action_space).shape[0]#["action"]

        self._policy_net = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim_actor),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim_actor, output_dim),
        )
        self._value_function = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim_critic),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim_critic, 1)
        )

    @override(TorchRLModule)
    def _forward_exploration(self, batch, **kwargs):
        observation = batch[Columns.OBS]["observation"]
        action_mask = batch[Columns.OBS]["action_mask"]
        with torch.no_grad():
            # Convert the action mask into a `[0.0][-inf]`-type mask.
            inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
            # Get the logits from the policy network.
            logits = self._policy_net(
                flatten_inputs_to_1d_tensor(
                    observation, self.observation_space["observation"]
                )
            )
            # Mask the logits so invalid actions are never sampled.
            logits_masked = logits + inf_mask
        return {Columns.ACTION_DIST_INPUTS: logits_masked}

    @override(TorchRLModule)
    def _forward_inference(self, batch, **kwargs):
        observation = batch[Columns.OBS]["observation"]
        action_mask = batch[Columns.OBS]["action_mask"]
        with torch.no_grad():
            # Convert the action mask into a `[0.0][-inf]`-type mask.
            inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
            # Get the logits from the policy network.
            logits = self._policy_net(
                flatten_inputs_to_1d_tensor(
                    observation, self.observation_space["observation"]
                )
            )
            # Mask the logits so invalid actions are never sampled.
            logits_masked = logits + inf_mask
        return {Columns.ACTION_DIST_INPUTS: logits_masked}
    """
    @override(TorchRLModule)
    def _forward(self, batch, **kwargs):
        observation, action_mask = batch[Columns.OBS]["observation"], batch[Columns.OBS]["action_mask"]
        # Convert action mask into an `[0.0][-inf]`-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        # Get the logits from the policy network.
        logits = self._policy_net(observation)
        # Mask the logits.
        logits_masked = logits + inf_mask
        
        if len(batch) > 1:
            torch.set_printoptions(profile="full")
            print("logits masked: ", logits_masked)
        return {Columns.ACTION_DIST_INPUTS: logits_masked}
        """

    @override(TorchRLModule)
    def _forward_train(self, batch, **kwargs):
        observation = batch[Columns.OBS]["observation"]
        action_mask = batch[Columns.OBS]["action_mask"]
        # Convert the action mask into a `[0.0][-inf]`-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        # Get the logits from the policy network (with gradients, for training).
        logits = self._policy_net(
            flatten_inputs_to_1d_tensor(
                observation, self.observation_space["observation"]
            )
        )
        # Mask the logits so invalid actions are never sampled.
        logits_masked = logits + inf_mask
        return {Columns.ACTION_DIST_INPUTS: logits_masked}

    @override(ValueFunctionAPI)
    def compute_values(self, batch, **kwargs):
        obs = flatten_inputs_to_1d_tensor(
            batch[Columns.OBS]["observation"], self.observation_space["observation"]
        )
        # Return value predictions of shape (B,), i.e., squeeze out the last dim.
        return self._value_function(obs).squeeze(-1)

You need to implement compute_values() from the ValueFunctionAPI if you're using the PPO algorithm: PPO's learner pipeline uses these value predictions to compute the GAE advantages, so without it the 'advantages' column never gets added to the train batch, which is exactly what raises the KeyError: 'advantages'.
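
For completeness, here is roughly how the module above plugs into the config. This is only a sketch: it assumes the environment now returns Dict observations with "observation" and "action_mask" keys (unlike the plain Box env I posted earlier), and the space construction and hidden-dim values below are placeholders, not something RLlib prescribes.

import numpy as np
from gymnasium.spaces import Box, Dict, Discrete

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

# Assumed per-agent observation space, matching what the module reads from the batch:
# a flat board observation plus a binary action mask.
masked_obs_space = Dict({
    "observation": Box(-1.0, 1.0, (9,), np.float32),
    "action_mask": Box(0.0, 1.0, (9,), np.float32),
})

config = (
    PPOConfig()
    .environment(TicTacToe)  # The env must emit the Dict observations above.
    .multi_agent(
        policies={"p1", "p2"},
        policy_mapping_fn=lambda agent_id, episode, **kw: f"p{agent_id[-1]}",
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={
                "p1": RLModuleSpec(
                    module_class=VPGTorchRLModule,
                    model_config={"hidden_dim_actor": 32, "hidden_dim_critic": 32},
                    observation_space=masked_obs_space,
                    action_space=Discrete(9),
                ),
                "p2": RLModuleSpec(
                    module_class=VPGTorchRLModule,
                    model_config={"hidden_dim_actor": 16, "hidden_dim_critic": 16},
                    observation_space=masked_obs_space,
                    action_space=Discrete(9),
                ),
            }
        )
    )
    .training(use_critic=True, use_gae=True)
)

ppo = config.build()
for _ in range(5):
    train_results = ppo.train()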