Context
I’m trying to set up inference for a multi-agent environment with action masking. While the learning process appears to work well, I’m encountering issues during inference setup. I’ve attempted two different approaches, both resulting in errors.
Approach 1: Using Connector Pipelines
I tried following the official example, but encountered issues. Here’s my implementation:
import os

from ray.rllib.connectors.env_to_module import EnvToModulePipeline
from ray.rllib.connectors.module_to_env import ModuleToEnvPipeline
from ray.rllib.core import (
    COMPONENT_ENV_RUNNER,
    COMPONENT_ENV_TO_MODULE_CONNECTOR,
    COMPONENT_MODULE_TO_ENV_CONNECTOR,
    COMPONENT_LEARNER_GROUP,
    COMPONENT_LEARNER,
    COMPONENT_RL_MODULE,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.env.single_agent_episode import SingleAgentEpisode
from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
from ray.rllib.utils.framework import try_import_torch

# my custom multi-agent env with action masks
from env.dispatch_env_cubes_ray import DispatchEnvCubesRay

torch, _ = try_import_torch()

CHECKPOINT_DIR = os.path.abspath('some_path_to_my_checkpoint')

env = DispatchEnvCubesRay(render_mode='human')

print("Restore env-to-module connector from checkpoint ...", end="")
env_to_module = EnvToModulePipeline.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_ENV_RUNNER,
        COMPONENT_ENV_TO_MODULE_CONNECTOR,
    )
)
print(" ok")

print("Restore RLModule from checkpoint ...", end="")
# Create RLModule from a checkpoint.
rl_module = RLModule.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_LEARNER_GROUP,
        COMPONENT_LEARNER,
        COMPONENT_RL_MODULE,
        'shared_policy',
    )
)
print(" ok")

print("Restore module-to-env connector from checkpoint ...", end="")
module_to_env = ModuleToEnvPipeline.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_ENV_RUNNER,
        COMPONENT_MODULE_TO_ENV_CONNECTOR,
    )
)
print(" ok")

num_episodes = 0

observations, _ = env.reset()
# input_dict = {Columns.OBS: observations}
# out = rl_module.forward_inference(observations)
episode = MultiAgentEpisode(
    observations=[observations],
    observation_space=env.observation_spaces,
    action_space=env.action_space,
)

while num_episodes < 2:
    shared_data = {}
    input_dict = env_to_module(
        episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes.
        rl_module=rl_module,
        explore=False,
        shared_data=shared_data,
    )
    rl_module_out = rl_module.forward_inference(input_dict)
    to_env = module_to_env(
        batch=rl_module_out,
        episodes=[episode],  # ConnectorV2 pipelines operate on lists of episodes.
        rl_module=rl_module,
        explore=False,
        shared_data=shared_data,
    )
    # Send the computed action to the env. Note that the RLModule and the
    # connector pipelines work on batched data (B=1 in this case), whereas the env
    # is not vectorized here, so we need to use `action[0]`.
    action = to_env.pop(Columns.ACTIONS)[0]
    obs, reward, terminated, truncated, _ = env.step(action)
    # Keep our episode instance updated at all times.
    episode.add_env_step(
        obs,
        action,
        reward,
        terminated=terminated,
        truncated=truncated,
        # Same here: [0] b/c RLModule output is batched (w/ B=1).
        extra_model_outputs={k: v[0] for k, v in to_env.items()},
    )
    # Is the episode `done`? -> Reset.
    if episode.is_done:
        print(f"Episode done: Total reward = {episode.get_return()}")
        obs, info = env.reset()
        episode = SingleAgentEpisode(
            observations=[obs],
            observation_space=env.observation_space,
            action_space=env.action_space,
        )
        num_episodes += 1
For some reason it crashes on the `input_dict = env_to_module(...)` line. Here's the console output:
2025-04-24 15:01:25,355 WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` has been deprecated. Use `ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module.DefaultPPOTorchRLModule` instead. This will raise an error in the future!
Restore env-to-module connector from checkpoint ... ok
2025-04-24 15:01:25,438 WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
Restore RLModule from checkpoint ... ok
Restore module-to-env connector from checkpoint ... ok
Traceback (most recent call last):
File "/virtual-dispatcher-environment-model/inference_pipeline_ray_v2.py", line 83, in <module>
input_dict = env_to_module(
^^^^^^^^^^^^^^
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/connectors/env_to_module/env_to_module_pipeline.py", line 38, in __call__
ret = super().__call__(
^^^^^^^^^^^^^^^^^
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/connectors/connector_pipeline_v2.py", line 123, in __call__
batch = connector(
^^^^^^^^^^
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/connectors/common/agent_to_module_mapping.py", line 180, in __call__
if rl_module is not None and column in rl_module:
^^^^^^^^^^^^^^^^^^^
TypeError: argument of type 'ActionMaskingTorchRLModule' is not iterable
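If I read the traceback right, the AgentToModuleMapping connector evaluates `column in rl_module`, which only makes sense for a MultiRLModule (where `in` checks module IDs), not for the single 'shared_policy' sub-module I restored. So maybe the connector pipelines need the whole MultiRLModule from the checkpoint instead. A sketch of what I mean (untested on my side, and using MultiRLModule.from_checkpoint is my assumption):

from ray.rllib.core.rl_module.multi_rl_module import MultiRLModule

# Restore the complete MultiRLModule (all module IDs) rather than only the
# 'shared_policy' sub-module, so the connector can do `column in rl_module`.
rl_module = MultiRLModule.from_checkpoint(
    os.path.join(
        CHECKPOINT_DIR,
        COMPONENT_LEARNER_GROUP,
        COMPONENT_LEARNER,
        COMPONENT_RL_MODULE,
    )
)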
Approach 2: Direct RLModule Usage
I also tried a more direct approach using the RLModule:
import os

import torch
from torch.distributions import Categorical

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from env.dispatch_env_cubes_ray import DispatchEnvCubesRay
from ray.tune.registry import register_env
from ray.rllib.algorithms.algorithm import Algorithm

CHECKPOINT_DIR = os.path.abspath('some_path_to_my_checkpoint')

register_env('cubes_ray', lambda _: DispatchEnvCubesRay(render_mode='human'))

algo = Algorithm.from_checkpoint(CHECKPOINT_DIR)

env: MultiAgentEnv = DispatchEnvCubesRay(render_mode='human')

rl_module = algo.get_module("shared_policy")
rl_module.eval()

obs, _ = env.reset()
all_rewards = []
sum_rewards = 0
FLOAT_MIN = torch.finfo(torch.float32).min

for step in range(1000):
    actions = {}
    for agent_id, agent_obs in obs.items():
        obs = agent_obs['observations']
        agent_action_mask = agent_obs['action_mask']
        obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        action_mask_tensor = torch.tensor(agent_action_mask, dtype=torch.float32)
        with torch.no_grad():
            action_out = rl_module.forward_inference({"obs": obs_tensor})
        logits = action_out["action_dist_inputs"]
        inf_mask = torch.clamp(torch.log(action_mask_tensor), min=FLOAT_MIN)
        masked_logits = logits + inf_mask
        dist = Categorical(logits=masked_logits)
        action = dist.sample().item()
        actions[agent_id] = action
    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    all_rewards.append(rewards)
    print(f"[STEP {step}] Rewards: {rewards}, Action: {actions}")
    sum_rewards += rewards['gruz_1']
    if terminateds.get("__all__", True) or truncateds.get("__all__", True):
        # print(all_rewards)
        print(f'Sum_rewards={sum_rewards}')
        break
This one fails on the line `action_out = rl_module.forward_inference({"obs": obs_tensor})`. Console output:
Traceback (most recent call last):
File "/virtual-dispatcher-environment-model/inference_pipeline_ray.py", line 38, in <module>
action_out = rl_module.forward_inference({"obs": obs_tensor})
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/core/rl_module/rl_module.py", line 564, in forward_inference
return self._forward_inference(batch, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/examples/rl_modules/classes/action_masking_rlm.py", line 94, in _forward_inference
action_mask, batch = self._preprocess_batch(batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/examples/rl_modules/classes/action_masking_rlm.py", line 145, in _preprocess_batch
self._check_batch(batch)
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/examples/rl_modules/classes/action_masking_rlm.py", line 192, in _check_batch
if "action_mask" not in batch[Columns.OBS]:
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/torch/_tensor.py", line 1225, in __contains__
raise RuntimeError(
RuntimeError: Tensor.__contains__ only supports Tensor or scalar, but you passed in a <class 'str'>.
To me it looks like ActionMaskingTorchRLModule wants to handle the action masks by itself but can't actually do so here. Or maybe I did something wrong (there aren't many examples combining MARL and action masks).
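Based on `_check_batch` in action_masking_rlm.py, I suspect the module expects the full dict observation (with both 'observations' and 'action_mask' keys) under the 'obs' key rather than a flat tensor, and that it masks the logits itself. If that's right, the per-agent call in Approach 2 would look roughly like this (untested sketch, reusing the variables from my loop above):

# Pass the whole batched dict obs so the RLModule can split off the mask itself.
input_dict = {
    "obs": {
        "observations": obs_tensor,                      # shape (1, obs_dim)
        "action_mask": action_mask_tensor.unsqueeze(0),  # shape (1, num_actions)
    }
}
with torch.no_grad():
    action_out = rl_module.forward_inference(input_dict)
# If the module already masks the logits, my manual log(action_mask) clamp
# above should be redundant.
logits = action_out["action_dist_inputs"]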
Questions
- Are there any examples of implementing inference with MARL and action masks in RLlib?
- What is the correct way to handle action masks during inference with `ActionMaskingTorchRLModule`?
- Is there something fundamentally wrong with my approach?