Any examples of multi-agent inference with action masking?

Context

I’m trying to set up inference for a multi-agent environment with action masking. Training appears to work well, but I’m running into issues when setting up inference. I’ve tried two different approaches, and both result in errors.
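
For reference, each agent’s observation in my env is a Dict with the actual observation under "observations" and the mask under "action_mask", i.e. the layout the RLlib action-masking example module expects. Roughly like this (shapes and sizes below are just placeholders, not my real env):

import numpy as np
from gymnasium import spaces

# Per-agent observation: the real observation vector plus a binary mask
# over the discrete actions (1 = allowed, 0 = forbidden).
single_agent_obs_space = spaces.Dict({
    'observations': spaces.Box(low=-np.inf, high=np.inf, shape=(16,), dtype=np.float32),
    'action_mask': spaces.Box(low=0.0, high=1.0, shape=(5,), dtype=np.float32),
})
single_agent_action_space = spaces.Discrete(5)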

Approach 1: Using Connector Pipelines

I tried following the official example, but encountered issues. Here’s my implementation:

import os
from ray.rllib.connectors.env_to_module import EnvToModulePipeline
from ray.rllib.connectors.module_to_env import ModuleToEnvPipeline
from ray.rllib.core import (
    COMPONENT_ENV_RUNNER,
    COMPONENT_ENV_TO_MODULE_CONNECTOR,
    COMPONENT_MODULE_TO_ENV_CONNECTOR,
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.utils.framework import try_import_torch

# my custom multi-agent env with action masks
from env.dispatch_env_cubes_ray import DispatchEnvCubesRay

torch, _ = try_import_torch()
CHECKPOINT_DIR = os.path.abspath('some_path_to_my_checkpoint')
env = DispatchEnvCubesRay(render_mode='human')


print("Restore env-to-module connector from checkpoint ...", end="")
env_to_module = EnvToModulePipeline.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_ENV_RUNNER,
        COMPONENT_ENV_TO_MODULE_CONNECTOR,
    )
)
print(" ok")

print("Restore RLModule from checkpoint ...", end="")
# Create RLModule from a checkpoint.
rl_module = RLModule.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_LEARNER_GROUP,
        COMPONENT_LEARNER,
        COMPONENT_RL_MODULE,
        'shared_policy',
    )
)
print(" ok")

print("Restore module-to-env connector from checkpoint ...", end="")
module_to_env = ModuleToEnvPipeline.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_ENV_RUNNER,
        COMPONENT_MODULE_TO_ENV_CONNECTOR,
    )
)
print(" ok")

num_episodes = 0

observations, _ = env.reset()
# input_dict = {Columns.OBS: observations}
# out = rl_module.forward_inference(observations)
episode = MultiAgentEpisode(
    observations=[observations],
    observation_space=env.observation_spaces,
    action_space=env.action_spaces,
)

while num_episodes < 2:
    shared_data = {}
    input_dict = env_to_module(
        episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes.
        rl_module=rl_module,
        explore=False,
        shared_data=shared_data,
    )

    rl_module_out = rl_module.forward_inference(input_dict)

    to_env = module_to_env(
        batch=rl_module_out,
        episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes.
        rl_module=rl_module,
        explore=False,
        shared_data=shared_data,
    )
    # Send the computed action to the env. Note that the RLModule and the
    # connector pipelines work on batched data (B=1 in this case), whereas the Env
    # is not vectorized here, so we need to use `action[0]`.
    action = to_env.pop(Columns.ACTIONS)[0]
    obs, reward, terminated, truncated, _ = env.step(action)
    # Keep our `MultiAgentEpisode` instance updated at all times.
    episode.add_env_step(
        obs,
        action,
        reward,
        terminated=terminated,
        truncated=truncated,
        # Same here: [0] b/c RLModule output is batched (w/ B=1).
        extra_model_outputs={k: v[0] for k, v in to_env.items()},
    )

    # Is the episode `done`? -> Reset.
    if episode.is_done:
        print(f"Episode done: Total reward = {episode.get_return()}")
        obs, info = env.reset()
        episode = MultiAgentEpisode(
            observations=[obs],
            observation_space=env.observation_spaces,
            action_space=env.action_spaces,
        )
        num_episodes += 1

For some reason it crashes on the input_dict = env_to_module(...) line. Here’s the console output:

2025-04-24 15:01:25,355 WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` has been deprecated. Use `ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module.DefaultPPOTorchRLModule` instead. This will raise an error in the future!
Restore env-to-module connector from checkpoint ... ok
2025-04-24 15:01:25,438 WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
Restore RLModule from checkpoint ... ok
Restore module-to-env connector from checkpoint ... ok
Traceback (most recent call last):
  File "/virtual-dispatcher-environment-model/inference_pipeline_ray_v2.py", line 83, in <module>
    input_dict = env_to_module(
                 ^^^^^^^^^^^^^^
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py", line 38, in __call__
    ret = super().__call__(
          ^^^^^^^^^^^^^^^^^
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/connectors/connector_pipeline_v2.py", line 123, in __call__
    batch = connector(
            ^^^^^^^^^^
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/connectors/common/agent_to_module_mapping.py", line 180, in __call__
    if rl_module is not None and column in rl_module:
                                 ^^^^^^^^^^^^^^^^^^^
TypeError: argument of type 'ActionMaskingTorchRLModule' is not iterable
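
My guess from the traceback: the AgentToModuleMapping connector’s check column in rl_module only works if rl_module is a MultiRLModule keyed by module IDs, whereas I restored the single ActionMaskingTorchRLModule for 'shared_policy'. Maybe restoring the whole MultiRLModule is what the multi-agent pipeline expects; something like this (just a sketch, I haven’t verified it):

from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule

# Restore the complete MultiRLModule (all module IDs, incl. 'shared_policy')
# instead of a single sub-module, so the connectors can look up columns per module ID.
rl_module = MultiRLModule.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_LEARNER_GROUP,
        COMPONENT_LEARNER,
        COMPONENT_RL_MODULE,
    )
)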

Approach 2: Direct RLModule Usage

I also tried a more direct approach using the RLModule:

import os
import torch
from torch.distributions import Categorical
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from env.dispatch_env_cubes_ray import DispatchEnvCubesRay
from ray.tune.registry import register_env
from ray.rllib.algorithms.algorithm import Algorithm


CHECKPOINT_DIR = os.path.abspath('some_path_to_my_checkpoint')
register_env('cubes_ray', lambda _: DispatchEnvCubesRay(render_mode='human'))

algo = Algorithm.from_checkpoint(CHECKPOINT_DIR)

env: MultiAgentEnv = DispatchEnvCubesRay(render_mode='human')
rl_module = algo.get_module("shared_policy")
rl_module.eval()

obs, _ = env.reset()
all_rewards = []
sum_rewards = 0
FLOAT_MIN = torch.finfo(torch.float32).min

for step in range(1000):
    actions = {}

    for agent_id, agent_obs in obs.items():

        agent_observations = agent_obs['observations']
        agent_action_mask = agent_obs['action_mask']

        obs_tensor = torch.tensor(agent_observations, dtype=torch.float32).unsqueeze(0)
        action_mask_tensor = torch.tensor(agent_action_mask, dtype=torch.float32)

        with torch.no_grad():
            action_out = rl_module.forward_inference({"obs": obs_tensor})

        logits = action_out["action_dist_inputs"]

        inf_mask = torch.clamp(torch.log(action_mask_tensor), min=FLOAT_MIN)
        masked_logits = logits + inf_mask

        dist = Categorical(logits=masked_logits)
        action = dist.sample().item()
        actions[agent_id] = action

    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    all_rewards.append(rewards)
    print(f"[STEP {step}] Rewards: {rewards}, Action: {actions}")
    sum_rewards += rewards['gruz_1']

    if terminateds.get("__all__", True) or truncateds.get("__all__", True):
        # print(all_rewards)
        print(f'Sum_rewards={sum_rewards}')
        break

This one fails on the action_out = rl_module.forward_inference({"obs": obs_tensor}) line. Console output:

Traceback (most recent call last):
  File "/virtual-dispatcher-environment-model/inference_pipeline_ray.py", line 38, in <module>
    action_out = rl_module.forward_inference({"obs": obs_tensor})
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/core/rl_module/rl_module.py", line 564, in forward_inference
    return self._forward_inference(batch, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/examples/rl_modules/classes/action_masking_rlm.py", line 94, in _forward_inference
    action_mask, batch = self._preprocess_batch(batch)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/examples/rl_modules/classes/action_masking_rlm.py", line 145, in _preprocess_batch
    self._check_batch(batch)
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/examples/rl_modules/classes/action_masking_rlm.py", line 192, in _check_batch
    if "action_mask" not in batch[Columns.OBS]:
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/torch/_tensor.py", line 1225, in __contains__
    raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.

To me it seems like ActionMaskingTorchRLModule wants to handle the action masks by itself but can’t really do so when I pass it a plain obs tensor. Or maybe I did something wrong (there aren’t many examples with MARL and action masks :slightly_frowning_face:).

Questions

  1. Are there any examples of implementing inference with MARL and action masks in RLlib?
  2. What is the correct way to handle action masks during inference with ActionMaskingTorchRLModule?
  3. Is there something fundamentally wrong with my approach?

Okay, somehow I’ve managed to solve this myself. The second approach should look like this:

for step in range(1000):
    actions = {}

    for agent_id, agent_obs in obs.items():

        agent_observations = agent_obs['observations']
        agent_action_mask = agent_obs['action_mask']

        obs_tensor = torch.tensor(agent_observations, dtype=torch.float32).unsqueeze(0)
        action_mask_tensor = torch.tensor(agent_action_mask, dtype=torch.float32)

        with torch.no_grad():
            action_out = rl_module.forward_inference({"obs": {'observations': obs_tensor,
                                                              'action_mask': action_mask_tensor}})

        logits = action_out["action_dist_inputs"]

        dist = Categorical(logits=logits)
        action = dist.sample().item()
        actions[agent_id] = action

    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    all_rewards.append(rewards)
    print(f"[STEP {step}] Rewards: {rewards}, Action: {actions}")
    sum_rewards += rewards['gruz_1']

    if terminateds.get("__all__", True) or truncateds.get("__all__", True):
        # print(all_rewards)
        print(f'Sum_rewards={sum_rewards}')
        break

And it really works for me.
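
Side note: since ActionMaskingTorchRLModule applies the mask inside forward_inference (invalid logits get pushed down toward FLOAT_MIN), sampling from Categorical should practically never pick a forbidden action. If you want a fully greedy rollout instead of a stochastic one, I think you could just take the argmax of the returned logits:

# greedy instead of stochastic action selection
action = torch.argmax(logits, dim=-1).item()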

PS: I’m still not sure if it works correctly. “Not crashing with errors” doesn’t always mean that everything is perfect, tbh.
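
One cheap check I might add (my own idea, not from any RLlib example) is to assert, inside the per-agent loop, that the sampled action is actually allowed by that agent’s mask:

# sanity check right after sampling: the chosen action must be allowed by the mask
assert agent_action_mask[action] == 1, f"{agent_id} picked masked-out action {action}"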