Simple multi-agent setup with action masking problems

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Has anyone been successful in making a MARL setup with action masking using the new API?

I’m on my second attempt, this time trying to adapt the current examples to multi-agent (before that I was trying with my own Zookeeper environment, a custom RLModule, and a custom PPOCatalog).

Anyway, where I’m stuck now is that I’m getting a NotImplementedError thrown in batch_individual_items.py. I’ll let some code follow along with the stack trace. If anyone has any hints for getting past this, or (even better) has an example of a working MARL setup with action masking, please comment :slight_smile:

main.py:

from gymnasium.spaces import Box, Dict, Discrete

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.envs.classes.action_mask_env import ActionMaskEnv
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule,
)
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec

from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)


parser = add_rllib_example_script_args(
    default_iters=10,
    default_timesteps=100000,
    default_reward=150.0,
)

spec = MultiRLModuleSpec(
    module_specs={
        "red": RLModuleSpec(
            module_class=ActionMaskingTorchRLModule,
            observation_space=Dict(
                {
                    "action_mask": Box(0.0, 1.0, shape=(2,)),
                    "observations": Box(-1.0, 1.0, (5,)),
                }
            ),
        ),
        "blue": RLModuleSpec(
            module_class=ActionMaskingTorchRLModule,
            observation_space=Dict(
                {
                    "action_mask": Box(0.0, 1.0, shape=(2,)),
                    "observations": Box(-1.0, 1.0, (5,)),
                }
            ),
        ),
    }
)

if __name__ == "__main__":
    args = parser.parse_args()

    if args.algo != "PPO":
        raise ValueError("This example only supports PPO. Please use --algo=PPO.")

    base_config = (
        PPOConfig()
        .api_stack(
            # This example runs only under the new API stack.
            enable_rl_module_and_learner=True,
            enable_env_runner_and_connector_v2=True,
        )
        .environment(
            env=ActionMaskEnv,
            env_config={
                "action_space": Discrete(100),
                # This defines the 'original' observation space that is used in the
                # `RLModule`. The environment will wrap this space into a
                # `gym.spaces.Dict` together with an 'action_mask' that signals the
                # `RLModule` to adapt the action distribution inputs for the underlying
                # `PPORLModule`.
                "observation_space": Box(-1.0, 1.0, (5,)),
            },
        )
        .multi_agent(
            policies={"red", "blue"},
            policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
            policies_to_train={"red", "blue"},
        )
        .rl_module(
            model_config_dict={
                "post_fcnet_hiddens": [64, 64],
                "post_fcnet_activation": "relu",
            },
            # We need to explicitly specify here the RLModule to use and
            # the catalog needed to build it.
            rl_module_spec=spec,
        )
        .evaluation(
            evaluation_num_env_runners=1,
            evaluation_interval=1,
            # Run evaluation parallel to training to speed up the example.
            evaluation_parallel_to_training=True,
        )
    )

    # Run the example (with Tune).
    run_rllib_example_script_experiment(base_config, args)
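
For reference, the same spec can also be written a bit more compactly (just a sketch that assumes both agents really do share the identical module class and masked observation space; AGENT_IDS and MASKED_OBS_SPACE are helper names I made up, using the same imports as in main.py above):

# Hypothetical, equivalent construction of the spec above.
AGENT_IDS = ["red", "blue"]
MASKED_OBS_SPACE = Dict(
    {
        "action_mask": Box(0.0, 1.0, shape=(2,)),
        "observations": Box(-1.0, 1.0, (5,)),
    }
)

spec = MultiRLModuleSpec(
    module_specs={
        agent_id: RLModuleSpec(
            module_class=ActionMaskingTorchRLModule,
            observation_space=MASKED_OBS_SPACE,
        )
        for agent_id in AGENT_IDS
    }
)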

Env (based on this):

from gymnasium.spaces import Box, Dict
import numpy as np

from ray.rllib.examples.envs.classes.random_env import RandomMultiAgentEnv

class ActionMaskEnv(RandomMultiAgentEnv):
    """A randomly acting environment that publishes an action-mask each step."""

    def __init__(self, config):
        super().__init__(config)

        self.possible_agents = ["red", "blue"]

        # Masking only works for Discrete actions. In the multi-agent case the
        # env's action_space is a Dict of per-agent spaces.
        assert isinstance(self.action_space, Dict)
        # Add action_mask to observations.
        self.unwrapped_space = Dict(
            {
                "action_mask": Box(0.0, 1.0, shape=(2,)),
                "observations": self.observation_space,
            }
        )
        self.observation_space = self.unwrapped_space
        self.observation_spaces = {name: self.unwrapped_space for name in self.possible_agents}
        self.valid_actions = None

    def reset(self, *, seed=None, options=None):
        obs, info = self.observation_space.sample(), {}
        self._fix_action_mask(obs)
        return obs, info
        
    def step(self, action):
        # Check whether action is valid.
        if not self.valid_actions[action]:
            raise ValueError(
                f"Invalid action ({action}) sent to env! "
                f"valid_actions={self.valid_actions}"
            )
        obs, rew, done, truncated, info = super().step(action)
        self._fix_action_mask(obs)
        return obs, rew, done, truncated, info

    def _fix_action_mask(self, obs):
        # Fix the action mask: everything larger than 0.5 becomes 1.0, everything else 0.0.
        self.valid_actions = np.round(obs["action_mask"])
        obs["action_mask"] = self.valid_actions

action_masking_rlm.py (based on this):

import gymnasium as gym
from typing import Dict, Tuple

from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.rl_module import RLModule, RLModuleConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()


class ActionMaskingRLModule(RLModule):
    """An RLModule that implements an action masking for safe RL.

    This RLModule implements action masking to avoid unsafe/unwanted actions
    dependent on the current state (observations). It does so by using an
    environment generated action mask defining which actions are allowed and
    which should be avoided. The action mask is extracted from the
    environment's `gymnasium.spaces.dict.Dict` observation and applied after
    the module's `forward`-pass to the action logits. The resulting action
    logits prevent unsafe/unwanted actions from being sampled from the
    corresponding action distribution.

    Note, this RLModule is implemented for the `PPO` algorithm only. It is not
    guaranteed to work with other algorithms. Furthermore, note that for this
    module to work it requires an environment with a `gymnasium.spaces.dict.Dict`
    observation space containing two keys, `"action_mask"` and `"observations"`.
    """

    @override(RLModule)
    def __init__(self, config: RLModuleConfig):
        # If observation space is not of type `Dict` raise an error.
        if not isinstance(config.observation_space, gym.spaces.dict.Dict):
            raise ValueError(
                "This RLModule requires the environment to provide a "
                "`gym.spaces.Dict` observation space of the form: \n"
                " {'action_mask': Box(0.0, 1.0, shape=(self.action_space.n,)),"
                "  'observations': self.observation_space}\n"
                f"Got: {config.observation_space}"
            )

        # While the environment holds an observation space that contains both
        # the action mask and the original observation space, the `RLModule`
        # receives only the `"observations"` element of the space, but not the
        # action mask.
        self.observation_space_with_mask = config.observation_space
        config.observation_space = config.observation_space["observations"]

        # Keeps track of whether observation specs have been checked already.
        self._checked_observations = False

        # The PPORLModule, in its constructor will build networks for the original
        # observation space (i.e. without the action mask).
        super().__init__(config)


class ActionMaskingTorchRLModule(ActionMaskingRLModule, PPOTorchRLModule):
    @override(PPOTorchRLModule)
    def setup(self):
        super().setup()
        # We need to reset the observation space here such that the
        # super's (`PPOTorchRLModule`) observation space is the
        # original space (i.e. without the action mask) and `self`'s
        # observation space contains the action mask.
        self.config.observation_space = self.observation_space_with_mask

    @override(PPOTorchRLModule)
    def _forward_inference(
        self, batch: Dict[str, TensorType], **kwargs
    ) -> Dict[str, TensorType]:
        # Preprocess the original batch to extract the action mask.
        action_mask, batch = self._preprocess_batch(batch)
        # Run the forward pass.
        outs = super()._forward_inference(batch, **kwargs)
        # Mask the action logits and return.
        return self._mask_action_logits(outs, action_mask)

    @override(PPOTorchRLModule)
    def _forward_exploration(
        self, batch: Dict[str, TensorType], **kwargs
    ) -> Dict[str, TensorType]:
        # Preprocess the original batch to extract the action mask.
        action_mask, batch = self._preprocess_batch(batch)
        # Run the forward pass.
        outs = super()._forward_exploration(batch, **kwargs)
        # Mask the action logits and return.
        return self._mask_action_logits(outs, action_mask)

    @override(PPOTorchRLModule)
    def _forward_train(
        self, batch: Dict[str, TensorType], **kwargs
    ) -> Dict[str, TensorType]:
        # Preprocess the original batch to extract the action mask.
        action_mask, batch = self._preprocess_batch(batch)
        # Run the forward pass.
        outs = super()._forward_train(batch, **kwargs)
        # Mask the action logits and return.
        return self._mask_action_logits(outs, action_mask)

    @override(ValueFunctionAPI)
    def compute_values(self, batch: Dict[str, TensorType]):
        # Preprocess the batch to move the `"observations"` into `Columns.OBS`.
        _, batch = self._preprocess_batch(batch)
        # Call the super's method to compute values for GAE.
        return super().compute_values(batch)

    def _preprocess_batch(
        self, batch: Dict[str, TensorType], **kwargs
    ) -> Tuple[TensorType, Dict[str, TensorType]]:
        """Extracts observations and action mask from the batch

        Args:
            batch: A dictionary containing tensors (at least `Columns.OBS`)

        Returns:
            A tuple with the action mask tensor and the modified batch containing
                the original observations.
        """
        # Check observation specs for action mask and observation keys.
        self._check_batch(batch)

        # Extract the available actions tensor from the observation.
        action_mask = batch[Columns.OBS].pop("action_mask")

        # Modify the batch for the `PPORLModule`'s `forward` method, i.e.
        # pass only `"obs"` into the `forward` method.
        batch[Columns.OBS] = batch[Columns.OBS].pop("observations")

        # Return the extracted action mask and the modified batch.
        return action_mask, batch

    def _mask_action_logits(
        self, batch: Dict[str, TensorType], action_mask: TensorType
    ) -> Dict[str, TensorType]:
        """Masks the action logits for the output of `forward` methods

        Args:
            batch: A dictionary containing tensors (at least action logits).
            action_mask: A tensor containing the action mask for the current
                observations.

        Returns:
            A modified batch with masked action logits for the action distribution
            inputs.
        """
        # Convert action mask into an `[0.0][-inf]`-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)

        # Mask the logits.
        batch[Columns.ACTION_DIST_INPUTS] += inf_mask

        # Return the batch with the masked action logits.
        return batch

    def _check_batch(self, batch: Dict[str, TensorType]) -> None:
        """Assert that the batch includes action mask and observations.

        Args:
            batch: A dictionary containing tensors (at least `Columns.OBS`) to be
                checked.

        Raises:
            `ValueError` if the column `Columns.OBS` does not contain observations
                and an action mask.
        """
        if not self._checked_observations:
            if "action_mask" not in batch[Columns.OBS]:
                raise ValueError(
                    "No action mask found in observation. This `RLModule` requires "
                    "the environment to provide observations that include an "
                    "action mask (i.e. an observation space of the Dict space "
                    "type that looks as follows: \n"
                    "{'action_mask': Box(0.0, 1.0, shape=(self.action_space.n,)),"
                    "'observations': self.observation_space}"
                )
            if "observations" not in batch[Columns.OBS]:
                raise ValueError(
                    "No observations found in observation. This 'RLModule` requires "
                    "the environment to provide observations that include the original "
                    "observations under a key `'observations'` in a dict (i.e. an "
                    "observation space of the Dict space type that looks as follows: \n"
                    "{'action_mask': Box(0.0, 1.0, shape=(self.action_space.n,)),"
                    "'observations': <observation_space>}"
                )
            self._checked_observations = True
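
As a sanity check of the masking math in _mask_action_logits, here is a tiny standalone snippet (independent of RLlib; the FLOAT_MIN value is a stand-in for RLlib's constant, since I only want to illustrate the log/clamp trick):

import torch

FLOAT_MIN = torch.finfo(torch.float32).min  # stand-in for RLlib's FLOAT_MIN

action_mask = torch.tensor([1.0, 0.0, 1.0])  # action 1 is invalid
logits = torch.tensor([0.2, 5.0, -0.3])

# log(1.0) == 0.0 (allowed); log(0.0) == -inf, clamped to a very large negative
# number, so the invalid action gets (numerically) zero probability after softmax.
inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
probs = torch.softmax(logits + inf_mask, dim=-1)
# probs is approx. [0.62, 0.00, 0.38], i.e. action 1 can no longer be sampled.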


STACK TRACE:

  File "/home/someone/.conda/envs/python310/lib/python3.10/site-packages/ray/rllib/utils/actor_manager.py", line 181, in apply
    return func(self, *args, **kwargs)
  File "/home/someone/.conda/envs/python310/lib/python3.10/site-packages/ray/rllib/algorithms/algorithm.py", line 1439, in _env_runner_remote
    episodes = worker.sample(
  File "/home/someone/.conda/envs/python310/lib/python3.10/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 184, in sample
    samples = self._sample_episodes(
  File "/home/someone/.conda/envs/python310/lib/python3.10/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 476, in _sample_episodes
    to_module = self._env_to_module(
  File "/home/someone/.conda/envs/python310/lib/python3.10/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py", line 25, in __call__
    return super().__call__(
  File "/home/someone/.conda/envs/python310/lib/python3.10/site-packages/ray/rllib/connectors/connector_pipeline_v2.py", line 85, in __call__
    data = connector(
  File "/home/someone/.conda/envs/python310/lib/python3.10/site-packages/ray/rllib/connectors/common/batch_individual_items.py", line 203, in __call__
    raise NotImplementedError
NotImplementedError

The data object going into the connector's __call__ method:

data = {'observations': {'obs': [OrderedDict([(0, array([-0.84334373, -0.6655395 , -0.40656194,  0.48559263,  0.3246651 ], dtype=float32))])]}, 'action_mask': {'obs': [array([0.93138945, 0.14899734], dtype=float32)]}}
is_multi_rl_module = True
rl_module = RL MODULE: MARL({'blue': ActionMaskingTorchRLModule(
   (encoder): TorchActorCriticEncoder(
     (actor_encoder): TorchMLPEncoder(
       (net): TorchMLP(
         (mlp): Sequential(
           (0): Linear(in_features=5, out_features=256, bias=True)
           (1): Tanh()
           (2): Linear(in_features=256, out_features=256, bias=True)
           (3): Tanh()
         )
       )
     )
     (critic_encoder): TorchMLPEncoder(
       (net): TorchMLP(
         (mlp): Sequential(
           (0): Linear(in_features=5, out_features=256, bias=True)
           (1): Tanh()
           (2): Linear(in_features=256, out_features=256, bias=True)
           (3): Tanh()
         )
       )
     )
   )
   (pi): TorchMLPHead(
     (net): TorchMLP(
       (mlp): Sequential(
         (0): Linear(in_features=256, out_features=64, bias=True)
          (1): ReLU()
          (2): Linear(in_features=64, out_features=64, bias=True)
         (3): ReLU()
         (4): Linear(in_features=64, out_features=100, bias=True)
       )
     )
   )
   (vf): TorchMLPHead(
     (net): TorchMLP(
       (mlp): Sequential(
         (0): Linear(in_features=256, out_features=64, bias=True)
         (1): ReLU()
         (2): Linear(in_features=64, out_features=64, bias=True)
         (3): ReLU()
         (4): Linear(in_features=64, out_features=1, bias=True)
       )
     )
   )
 ),
  'red': ActionMaskingTorchRLModule(
   (encoder): TorchActorCriticEncoder(
     (actor_encoder): TorchMLPEncoder(
       (net): TorchMLP(
         (mlp): Sequential(
           (0): Linear(in_features=5, out_features=256, bias=True)
           (1): Tanh()
           (2): Linear(in_features=256, out_features=256, bias=True)
           (3): Tanh()
         )
       )
     )
     (critic_encoder): TorchMLPEncoder(
       (net): TorchMLP(
         (mlp): Sequential(
           (0): Linear(in_features=5, out_features=256, bias=True)
           (1): Tanh()
           (2): Linear(in_features=256, out_features=256, bias=True)
           (3): Tanh()
         )
       )
     )
   )
   (pi): TorchMLPHead(
     (net): TorchMLP(
       (mlp): Sequential(
         (0): Linear(in_features=256, out_features=64, bias=True)
         (1): ReLU()
         (2): Linear(in_features=64, out_features=64, bias=True)
         (3): ReLU()
         (4): Linear(in_features=64, out_features=100, bias=True)
       )
     )
   )
   (vf): TorchMLPHead(
     (net): TorchMLP(
       (mlp): Sequential(
         (0): Linear(in_features=256, out_features=64, bias=True)
         (1): ReLU()
         (2): Linear(in_features=64, out_features=64, bias=True)
         (3): ReLU()
         (4): Linear(in_features=64, out_features=1, bias=True)
       )
     )
   )
 )})
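
One more observation on the data object above: its top-level keys are the Dict-space keys ('observations', 'action_mask') rather than my agent IDs ('red', 'blue'). Purely as a guess (I don't know the exact format the env-to-module connector expects), I would have imagined something agent-keyed, roughly like:

# Hypothetical shape only; not taken from a working run.
data = {
    "red": {"obs": [...]},   # per-agent Dict obs, including its action_mask
    "blue": {"obs": [...]},
}

which is also what makes me suspect the per-agent return format I sketched after my env code above.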