Do multi-agent environments need to specify an "action_space"?

I am trying to write a multi-agent environment with action masking using the new stack following the examples in the API documentation.

I have written a program that trains PPO on a very minimal random environment. (I realize that there are examples in ray.rllib.examples.multi_agent, but these are all run through the extremely complicated run_rllib_example_script_experiment function, which I know will be a nightmare to debug if I make one little mistake.) I want an example that is easy for myself and others to follow.

Here is what I have so far.

import random
from typing import Any, Optional

from gymnasium import Env
from gymnasium.spaces import Discrete, Box, Dict, MultiBinary
from numpy import ndarray
from ray.rllib.algorithms import PPOConfig
from ray.rllib.core.rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule,
)


class MaskedMultiAgentRandomEnvironment(Env):
    """A custom OpenAI Gym multi-agent environment with action masking.

    This environment consists of 3 simultaneously-interacting agents with 5 continuous observations
    and 2 discrete actions each which can be masked. Each agent returns random observations, action
    masks, and rewards and has a 10% chance of termination after each step.
    """

    AGENTS = 3
    OBSERVATIONS = 5
    ACTIONS = 2

    def __init__(self, _: dict[str, Any] = ()):
        super().__init__()
        self.agents = self.possible_agents
        self.observation_spaces = {}
        self.action_spaces = {}
        for agent in self.agents:
            self.observation_spaces[agent] = Dict(
                {
                    "observations": (Box(0, 1, (self.OBSERVATIONS,))),
                    "action_mask": MultiBinary(self.ACTIONS),
                }
            )
            self.action_spaces[agent] = Discrete(self.ACTIONS)

    @property
    def possible_agents(self) -> set[str]:
        return {f"Agent {i}" for i in range(self.AGENTS)}

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None
    ) -> tuple[dict, dict[str, Any]]:
        super().reset(seed=seed, options=options)
        self.agents = self.possible_agents
        observations = self.random_agent_observations()
        return observations, {}

    def step(self, actions: dict[str, int]) -> tuple[
        dict[str, dict[str, ndarray]],  # Observations
        dict[str, float],  # Rewards
        dict[str, bool],  # Terminations
        dict[str, bool],  # Truncations
        dict[str, Any],  # Info
    ]:
        observations = self.random_agent_observations()
        rewards = self.random_agent_rewards()
        terminations = self.random_agent_terminations()
        self.agents -= set(agent for agent, done in terminations.items() if done)
        return observations, rewards, terminations, {}, {}

    def random_agent_observations(self) -> dict[str, dict[str, ndarray]]:
        return {agent: self.observation_space.sample() for agent in self.agents}

    def random_agent_rewards(self) -> dict[str, float]:
        return {agent: random.random() for agent in self.agents}

    def random_agent_terminations(self) -> dict[str, bool]:
        return {
            agent: random.choices([True, False], weights=[0.1, 0.9])[0]
            for agent in self.agents
        }


POLICY_ID = "Shared Policy"


def train_multi_agent_masked_algorithm():
    """
    Train a model for the multi-agent random environment using the PPO algorithm with default
    parameters and a single policy shared between all the agents.
    """
    module_spec = RLModuleSpec(module_class=ActionMaskingTorchRLModule)
    base_config = (
        PPOConfig()
        .environment(
            env=MaskedMultiAgentRandomEnvironment,
            disable_env_checking=True,
        )
        .multi_agent(policy_mapping_fn=lambda _, __, **___: POLICY_ID)
        .rl_module(
            rl_module_spec=MultiRLModuleSpec(rl_module_specs={POLICY_ID: module_spec})
        )
    )
    algo = base_config.build_algo()
    print(algo)
    result = algo.train()
    print(result)


if __name__ == "__main__":
    train_multi_agent_masked_algorithm()

It fails in gymnasium.wrappers.common.PassiveEnvChecker because the environment does not specify an action_space member.

2025-03-02 12:25:26,525	INFO worker.py:1841 -- Started a local Ray instance.
2025-03-02 12:25:28,369	ERROR actor_manager.py:833 -- Ray error (The actor died because of an error raised in its creation task, ray::SingleAgentEnvRunner.__init__() (pid=31569, ip=127.0.0.1, actor_id=4222258f312567add75fc51e01000000, repr=<ray.rllib.env.single_agent_env_runner.SingleAgentEnvRunner object at 0x12fe72570>)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/single_agent_env_runner.py", line 100, in __init__
    self.make_env()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/single_agent_env_runner.py", line 654, in make_env
    gym.make_vec(
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 920, in make_vec
    env = gym.vector.SyncVectorEnv(
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/gymnasium/vector/sync_vector_env.py", line 96, in __init__
    self.envs = [env_fn() for env_fn in env_fns]
                 ^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 905, in create_single_env
    single_env = make(env_spec, **env_spec_kwargs.copy())
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/gymnasium/envs/registration.py", line 805, in make
    env = gym.wrappers.PassiveEnvChecker(env)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/gymnasium/wrappers/common.py", line 261, in __init__
    raise AttributeError(
AttributeError: The environment must specify an action space. https://gymnasium.farama.org/introduction/create_custom_env/), taking actor 1 out of service.

According to the RLlib multi-agent documentation, multi-agent environments are supposed to specify an action_spaces member, not action_space.

Why is this failing?

From Reddit and various RLlib support groups, it appears that I’m not the only one blocked on multi-agent environments and action masking. The examples aren’t well documented, and the tiniest modification launches you into an extremely difficult debugging task. I’m trying to help by writing simple, well-documented, and up-to-date examples that anyone can follow.

I have a masked single-agent example in the Ray Masked project that I think will be helpful to members of the Ray community. It took me a few weeks of debugging to produce it.

I am confident that with a few more weeks of debugging I could figure the multi-agent case out on my own as well, but it would be really helpful if I could get some guidance from a Ray developer. I think this would be an easy way to unblock multiple people.

I’m working on the masked multi-agent example on the multi-agent branch of the project.

Update: my leeriness of run_rllib_example_script_experiment was unjustified. It is a bit confusing, but it does let you step through code that runs on workers in the debugger.
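
For anyone who finds this thread later, the bundled example scripts follow roughly this pattern (a sketch from memory; the exact helper arguments vary between Ray versions, and the CartPole config is just a placeholder):

from ray.rllib.algorithms import PPOConfig
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
    run_rllib_example_script_experiment,
)

# Parser with the standard example-script options (--num-env-runners, --stop-iters, etc.).
parser = add_rllib_example_script_args()

if __name__ == "__main__":
    args = parser.parse_args()
    # Any AlgorithmConfig works here; CartPole-v1 just keeps the sketch minimal.
    base_config = PPOConfig().environment("CartPole-v1")
    run_rllib_example_script_experiment(base_config, args)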

@wpm thanks for raising this issue. As far as I can see from the code above, it needs our MultiAgentEnv as a parent. I see you implemented it in your branch. Running a multi-agent environment that derives from MultiAgentEnv should not run into this error anymore. Can you confirm this?

In regard to your custom environment: I guess you will need to define the __all__ key in the terminateds and truncateds dictionaries because these are checked when ending a MultiAgentEpisode.

Thanks so much for your response.

I figured out that I needed to derive from MultiAgentEnv. My revised code is in the multi-agent branch of my Ray Masked project.

This is my current version of masked multi-agent training.

import random
import re
from typing import Any, Optional

from gymnasium.spaces import Discrete, Box, Dict, MultiBinary
from numpy import ndarray
from ray.rllib.algorithms import PPOConfig, AlgorithmConfig
from ray.rllib.core.rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import (
    ActionMaskingTorchRLModule,
)


class MaskedMultiAgentRandomEnvironment(MultiAgentEnv):
    """A custom OpenAI Gym multi-agent environment with action masking.

    This environment consists of 3 simultaneously-interacting agents with 5 continuous observations
    and 2 discrete actions each which can be masked. Each agent returns random observations, action
    masks, and rewards and has a 10% chance of termination after each step.
    """

    AGENTS = 3
    OBSERVATIONS = 5
    ACTIONS = 2

    def __init__(self, _: dict[str, Any] = ()):
        super().__init__()
        self.agents = self.possible_agents[:]
        self.observation_spaces = {}
        self.action_spaces = {}
        for agent in self.agents:
            self.observation_spaces[agent] = Dict(
                {
                    "observations": (Box(0, 1, (self.OBSERVATIONS,))),
                    "action_mask": MultiBinary(self.ACTIONS),
                }
            )
            self.action_spaces[agent] = Discrete(self.ACTIONS)

    def reset(
        self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None
    ) -> tuple[dict, dict[str, Any]]:
        super().reset(seed=seed, options=options)
        self.agents = self.possible_agents[:]
        observations = self.random_agent_observations()
        return observations, {}

    def step(self, actions: dict[str, int]) -> tuple[
        dict[str, dict[str, ndarray]],  # Observations
        dict[str, float],  # Rewards
        dict[str, bool],  # Terminations
        dict[str, bool],  # Truncations
        dict[str, Any],  # Info
    ]:
        observations = self.random_agent_observations()
        rewards = self.random_agent_rewards()
        terminations = self.random_agent_terminations()
        self.agents = self.remaining_agents(terminations)
        return observations, rewards, terminations, {}, {}

    @property
    def possible_agents(self) -> list[str]:
        return [f"Agent {i}" for i in range(self.AGENTS)]

    def remaining_agents(self, agent_done: dict[str, bool]) -> list[str]:
        """
        :param agent_done: a dictionary mapping agent names to whether they have been terminated or not
        :return: a list of agents that are still active
        """
        done = set(agent for agent, done in agent_done.items() if done)
        return sorted(
            set(self.agents) - done,
            key=lambda agent: int(re.match(r"Agent (\d+)", agent).group(1)),
        )

    def random_agent_observations(self) -> dict[str, dict[str, ndarray]]:
        """
        Generate a random observation and action mask for each active agent.
        """
        return {agent: self.observation_spaces[agent].sample() for agent in self.agents}

    def random_agent_rewards(self) -> dict[str, float]:
        """
        Generate a random reward between 0 and 1 for each agent.
        """
        return {agent: random.random() for agent in self.agents}

    def random_agent_terminations(self) -> dict[str, bool]:
        """
        Terminate active agents at a rate of 10%. Return terminated agents as well as those that were already
        terminated.
        """
        return {
            agent: (
                random.choices([True, False], weights=[0.1, 0.9])[0]
                if agent in self.agents
                else True
            )
            for agent in self.possible_agents
        }


POLICY_ID = "Shared Policy"


def multi_agent_config() -> AlgorithmConfig:
    module_spec = RLModuleSpec(module_class=ActionMaskingTorchRLModule)
    return (
        PPOConfig()
        .environment(env=MaskedMultiAgentRandomEnvironment)
        .multi_agent(
            policies={POLICY_ID}, policy_mapping_fn=lambda _, __, **___: POLICY_ID
        )
        .rl_module(
            rl_module_spec=MultiRLModuleSpec(rl_module_specs={POLICY_ID: module_spec})
        )
    )


def train_multi_agent_masked_algorithm():
    """
    Train a model for the multi-agent random environment using the PPO algorithm with default
    parameters and a single policy shared between all the agents.
    """
    base_config = multi_agent_config()
    algo = base_config.build_algo()
    print(algo)
    result = algo.train()
    print(result)


if __name__ == "__main__":
    train_multi_agent_masked_algorithm()

When I run it I see the following error.

2025-03-04 13:30:51,325	INFO worker.py:1841 -- Started a local Ray instance.
(MultiAgentEnvRunner pid=2376) 2025-03-04 13:30:53,836	WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` has been deprecated. Use `ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module.DefaultPPOTorchRLModule` instead. This will raise an error in the future!
(MultiAgentEnvRunner pid=2376) 2025-03-04 13:30:53,903	WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
2025-03-04 13:30:53,990	WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
PPO(env=<class '__main__.MaskedMultiAgentRandomEnvironment'>; env-runners=2; learners=0; multi-agent=True)
(MultiAgentEnvRunner pid=2376) 2025-03-04 13:30:54,857	ERROR actor_manager.py:187 -- Worker exception caught during `apply()`: '__all__'
(MultiAgentEnvRunner pid=2376) Traceback (most recent call last):
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/utils/actor_manager.py", line 183, in apply
(MultiAgentEnvRunner pid=2376)     return func(self, *args, **kwargs)
(MultiAgentEnvRunner pid=2376)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/execution/rollout_ops.py", line 110, in <lambda>
(MultiAgentEnvRunner pid=2376)     else (lambda w: (w.sample(**random_action_kwargs), w.get_metrics()))
(MultiAgentEnvRunner pid=2376)                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span
(MultiAgentEnvRunner pid=2376)     return method(self, *_args, **_kwargs)
(MultiAgentEnvRunner pid=2376)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 208, in sample
(MultiAgentEnvRunner pid=2376)     samples = self._sample(
(MultiAgentEnvRunner pid=2376)               ^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span
(MultiAgentEnvRunner pid=2376)     return method(self, *_args, **_kwargs)
(MultiAgentEnvRunner pid=2376)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 347, in _sample
(MultiAgentEnvRunner pid=2376)     results = self._try_env_step(actions_for_env)
(MultiAgentEnvRunner pid=2376)               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span
(MultiAgentEnvRunner pid=2376)     return method(self, *_args, **_kwargs)
(MultiAgentEnvRunner pid=2376)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/env_runner.py", line 201, in _try_env_step
(MultiAgentEnvRunner pid=2376)     raise e
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/env_runner.py", line 186, in _try_env_step
(MultiAgentEnvRunner pid=2376)     results = self.env.step(actions)
(MultiAgentEnvRunner pid=2376)               ^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/vector/sync_vector_multi_agent_env.py", line 126, in step
(MultiAgentEnvRunner pid=2376)     np.array([t["__all__"] for t in self._terminations]),
(MultiAgentEnvRunner pid=2376)               ~^^^^^^^^^^^
(MultiAgentEnvRunner pid=2376) KeyError: '__all__'
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff9bffec7ffcb9301b4dbcb3b601000000 Worker ID: 30fe9898b2aa7f3bae16f4664cda41b872c505aeab7e8c4d2e84f569 Node ID: 16d3a09e9042284ec005978e8f23fced70fb1912e03cae20d25d3dd9 Worker IP address: 127.0.0.1 Worker port: 49768 Worker PID: 2379 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker exits unexpectedly. Worker exits with an exit code 1.

So it’s failing because I’m not including the special __all__ agent ID in the keys of the returned terminations dictionary, and you also mention this flag. But the documentation for the step return values says the following:

Use the special agent ID __all__ in the termination dicts and/or truncation dicts to indicate that the episode should end for all agent IDs, regardless of which agents are still active at that point. RLlib automatically terminates all agents in this case and ends the episode.

I interpret this to mean that I should not return the __all__ flag if any agents are still running.

I don’t understand what this return value should look like. Can you provide an example return value and/or show how to rewrite the following function?

    def random_agent_terminations(self) -> dict[str, bool]:
        """
        Terminate active agents at a rate of 10%. Return terminated agents as well as those that were already
        terminated.
        """
        return {
            agent: (
                random.choices([True, False], weights=[0.1, 0.9])[0]
                if agent in self.agents
                else True
            )
            for agent in self.possible_agents
        }

I can get training to complete if I always include __all__ in the returned terminateds dictionary like so:

    def random_agent_terminations(self) -> dict[str, bool]:
        """
        Terminate active agents at a rate of 10%. Return terminated agents as well as those that were already
        terminated.
        """
        terminations = {
            agent: (
                random.choices([True, False], weights=[0.1, 0.9])[0]
                if agent in self.agents
                else True
            )
            for agent in self.possible_agents
        }
        terminations["__all__"] = all(terminations.values())
        return terminations

The documentation led me to believe that __all__ was just a shortcut for when all the agents happened to be in the same state, but it appears that it is always necessary. How am I supposed to use __all__?

Hey @wpm, great that you figured this out! You are almost there. What you want to return from step() are two dictionaries for terminateds and truncateds that always contain __all__ in their keys, like

terminateds = {
   "a1": False, 
   "a2": True,
   "__all__": False,
}

Set __all__=True only if all your agents died (or were truncated). Because we cannot (and should not) track all agents that have ever lived or died in an environment inside the return values of step() (think of a multi-agent env in which agents are born and die at arbitrary points, so there could in theory be infinitely many), we instead carry the __all__ key, which tells us when the episode is done.

So have I got my code change right? __all__ is always False unless all agents in MultiAgentEnv.possible_agents have been terminated?

Should I be returning {"__all__": False} for truncation as well? (Even though I get away with returning an empty dictionary here?)

A couple of issues. When I put __all__ in the dictionary, I get the following warning, although everything works fine.

(MultiAgentEnvRunner pid=4791) 2025-03-04 14:13:19,765	ERROR multi_agent_env_runner.py:817 -- The element returned by step, next_obs has agent_ids that are not the names of the agents in the env. 
(MultiAgentEnvRunner pid=4791) AgentIDs in this MultiAgentDict: ['Agent 0', 'Agent 1', 'Agent 2']
(MultiAgentEnvRunner pid=4791) AgentIDs in this env: ['Agent 0', 'Agent 1']. You likely need to add the attribute `agents` to your env, which is a list containing the IDs of agents currently in your env/episode, as well as, `possible_agents`, which is a list of all possible agents that could ever show up in your env.
(MultiAgentEnvRunner pid=4791) Traceback (most recent call last):
(MultiAgentEnvRunner pid=4791)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 815, in make_env
(MultiAgentEnvRunner pid=4791)     check_multiagent_environments(env.unwrapped)
(MultiAgentEnvRunner pid=4791)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/utils/pre_checks/env.py", line 74, in check_multiagent_environments
(MultiAgentEnvRunner pid=4791)     _check_if_element_multi_agent_dict(env, next_obs, "step, next_obs")
(MultiAgentEnvRunner pid=4791)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/utils/pre_checks/env.py", line 240, in _check_if_element_multi_agent_dict
(MultiAgentEnvRunner pid=4791)     raise ValueError(error)
(MultiAgentEnvRunner pid=4791) ValueError: The element returned by step, next_obs has agent_ids that are not the names of the agents in the env. 
(MultiAgentEnvRunner pid=4791) AgentIDs in this MultiAgentDict: ['Agent 0', 'Agent 1', 'Agent 2']
(MultiAgentEnvRunner pid=4791) AgentIDs in this env: ['Agent 0', 'Agent 1']. You likely need to add the attribute `agents` to your env, which is a list containing the IDs of agents currently in your env/episode, as well as, `possible_agents`, which is a list of all possible agents that could ever show up in your env.

I tried to turn this off by passing the environment flag disable_env_checking=True, but then I got the following assertion error, which stops the run.

AssertionError: ERROR: When using the `MultiAgentEnvRunner` the environment needs to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`.
Traceback (most recent call last):
  File "/Users/mcneill/src/ray-masked/train_multiagent_masked_algorithm.py", line 137, in <module>
    train_multi_agent_masked_algorithm()
  File "/Users/mcneill/src/ray-masked/train_multiagent_masked_algorithm.py", line 130, in train_multi_agent_masked_algorithm
    algo = base_config.build_algo()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm_config.py", line 958, in build_algo
    return algo_class(
           ^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 528, in __init__
    super().__init__(
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/tune/trainable/trainable.py", line 157, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py", line 631, in setup
    self.env_runner_group = EnvRunnerGroup(
                            ^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/env_runner_group.py", line 198, in __init__
    self._setup(
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/env_runner_group.py", line 293, in _setup
    self._local_env_runner = self._make_worker(
                             ^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/env_runner_group.py", line 1207, in _make_worker
    return self.env_runner_cls(**kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 106, in __init__
    self.make_env()
  File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 820, in make_env
    assert isinstance(self.env.unwrapped, MultiAgentEnv), (
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: ERROR: When using the `MultiAgentEnvRunner` the environment needs to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`.
(MultiAgentEnvRunner pid=5102) 2025-03-04 14:23:22,720	WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` has been deprecated. Use `ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module.DefaultPPOTorchRLModule` instead. This will raise an error in the future!
(MultiAgentEnvRunner pid=5102) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::MultiAgentEnvRunner.__init__() (pid=5102, ip=127.0.0.1, actor_id=8178ae58cd9f33fef9570a7a01000000, repr=<ray.rllib.env.multi_agent_env_runner.MultiAgentEnvRunner object at 0x133b9c140>)
(MultiAgentEnvRunner pid=5102)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=5102)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=5102)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 106, in __init__
(MultiAgentEnvRunner pid=5102)     self.make_env()
(MultiAgentEnvRunner pid=5102)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=5102)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 820, in make_env
(MultiAgentEnvRunner pid=5102)     assert isinstance(self.env.unwrapped, MultiAgentEnv), (
(MultiAgentEnvRunner pid=5102)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=5102) AssertionError: ERROR: When using the `MultiAgentEnvRunner` the environment needs to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`.
(MultiAgentEnvRunner pid=5098) 2025-03-04 14:23:22,720	WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` has been deprecated. Use `ray.rllib.algorithms.ppo.torch.default_ppo_torch_rl_module.DefaultPPOTorchRLModule` instead. This will raise an error in the future!
(MultiAgentEnvRunner pid=5098) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::MultiAgentEnvRunner.__init__() (pid=5098, ip=127.0.0.1, actor_id=cd9348b72d647a95b34d53db01000000, repr=<ray.rllib.env.multi_agent_env_runner.MultiAgentEnvRunner object at 0x16fab0320>)
(MultiAgentEnvRunner pid=5098)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=5098)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ [repeated 2x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(MultiAgentEnvRunner pid=5098)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 106, in __init__
(MultiAgentEnvRunner pid=5098)     self.make_env()
(MultiAgentEnvRunner pid=5098)   File "/Users/mcneill/miniforge3/envs/ray/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 820, in make_env
(MultiAgentEnvRunner pid=5098)     assert isinstance(self.env.unwrapped, MultiAgentEnv), (
(MultiAgentEnvRunner pid=5098)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(MultiAgentEnvRunner pid=5098) AssertionError: ERROR: When using the `MultiAgentEnvRunner` the environment needs to inherit from `ray.rllib.env.multi_agent_env.MultiAgentEnv`.

Process finished with exit code 1

The assertion is incorrect: I am inheriting from MultiAgentEnv. The failing assertion is assert isinstance(self.env.unwrapped, MultiAgentEnv) in MultiAgentEnvRunner, but it looks like this was addressed in “[RLlib] Fix MultiAgentEnvRunner env check bug. (#50891)”.

@wpm In regard to the truncation, I would do so for completeness and to be sure that no future code changes break your code, e.g. any checks on truncated["__all__"].
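
For example, a minimal sketch of that for your env (assuming the env itself never truncates agents; the helper name is just for illustration):

    def random_agent_truncations(self) -> dict[str, bool]:
        """
        Sketch only: this env never truncates agents on its own, so every per-agent
        flag stays False, but the special __all__ key is still included.
        """
        truncations = {agent: False for agent in self.possible_agents}
        truncations["__all__"] = False
        return truncations

step() would then return this dict in place of the empty truncations dict.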

The second error you encountered - sorry for that - should have been fixed in the latest Ray version (if you can update). See here

Hello everyone,

I am currently running into another error with the MultiAgentEnvRunner when using my custom env; I thought it might be relevant to this thread:

File "/home/clemente/miniconda3/envs/autops-rl/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 107, in __init__
    self.make_env()
  File "/home/clemente/miniconda3/envs/autops-rl/lib/python3.12/site-packages/ray/rllib/env/multi_agent_env_runner.py", line 795, in make_env
    self.env = make_vec(
               ^^^^^^^^^
  File "/home/clemente/miniconda3/envs/autops-rl/lib/python3.12/site-packages/ray/rllib/env/vector/registration.py", line 69, in make_vec
    env = SyncVectorMultiAgentEnv(
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/clemente/miniconda3/envs/autops-rl/lib/python3.12/site-packages/ray/rllib/env/vector/sync_vector_multi_agent_env.py", line 37, in __init__
    self.single_action_spaces = self.envs[0].unwrapped.action_spaces or dict(
                                                                        ^^^^^
TypeError: 'NoneType' object is not iterable

I tried defining get_observation_space and get_action_space, and defining the observation and action spaces in several different ways, but no luck.
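
(By defining get_observation_space and get_action_space I mean per-agent getters of roughly this shape; a simplified sketch, not my exact code:)

    def get_observation_space(self, agent_id):
        # Look up the per-agent space from the dicts defined in __init__.
        return self.observation_spaces[agent_id]

    def get_action_space(self, agent_id):
        return self.action_spaces[agent_id]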

At the moment my spaces are defined in my class FSS_env(MultiAgentEnv) and they look like this:

self._action_space = spaces.Discrete(3)

self._observation_space = spaces.Dict({
    "observer_satellites": spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_observers, len(self.orbital_params_order))),
    "band": spaces.Box(low=1, high=5, shape=(1,), dtype=np.int8),
    "target_satellites": spaces.Box(low=-np.inf, high=np.inf, shape=(self.num_targets, len(self.orbital_params_order_targets))),
    "availability": spaces.MultiBinary(1),
    "battery": spaces.Box(low=0, high=1, shape=(self.num_observers,)),
    "storage": spaces.Box(low=0, high=1, shape=(self.num_observers,)),
    "observation_status": spaces.Box(low=0, high=3, shape=(self.num_targets,)),
    "pointing_accuracy": spaces.Box(low=-0, high=1, shape=(self.num_observers, self.num_targets)),
    "communication_status": spaces.Box(low=0, high=1, shape=(self.num_observers,), dtype=np.int8),
    "communication_ability": spaces.MultiBinary(self.num_observers)
})

self.action_spaces = {
    agent_id: self._action_space
    for agent_id in self.possible_agents
}
self.observation_spaces = {
    agent_id: self._observation_space
    for agent_id in self.possible_agents
}

self.action_space = self._action_space
self.observation_space = self._observation_space
self.single_observation_spaces = self._observation_space
self.single_action_spaces = self._action_space
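
For reference, the working example earlier in this thread assigns the per-agent space dicts right after super().__init__(); a trimmed sketch of that pattern (the class name and space sizes here are just placeholders):

from typing import Any

from gymnasium.spaces import Box, Dict, Discrete, MultiBinary
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class SpacesPatternSketch(MultiAgentEnv):
    """Trimmed sketch of the space setup used in the working example above."""

    def __init__(self, _: dict[str, Any] = ()):
        # As in that example, super().__init__() runs before the spaces are assigned.
        super().__init__()
        self.agents = self.possible_agents[:]
        self.observation_spaces = {
            agent: Dict(
                {
                    "observations": Box(0, 1, (5,)),
                    "action_mask": MultiBinary(2),
                }
            )
            for agent in self.agents
        }
        self.action_spaces = {agent: Discrete(2) for agent in self.agents}

    @property
    def possible_agents(self) -> list[str]:
        return [f"Agent {i}" for i in range(3)]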

Also, I tried applying FlattenObservations directly through an env_to_module_connector in the config:

.env_runners(
    num_env_runners=args.num_env_runners,
    num_envs_per_env_runner=args.num_envs_per_runner,
    num_cpus_per_env_runner=args.num_cpus_per_runner,
    num_gpus_per_env_runner=args.num_gpus_per_runner,
    explore=True,
    env_to_module_connector=_env_to_module_pipeline,
)

def _env_to_module_pipeline(env):
    return FlattenObservations(
        input_observation_space=env.observation_space,
        input_action_space=env.action_space,
        multi_agent=True
    )

I am running out of ideas to try. I saw that Simon and Sven still have to check the definition of get_observation_space(agent) and get_action_space(agent) in sync_vector_multi_agent_env.py, so I don’t know whether that will solve my issue. I am happy to provide more info or help with anything you need from my side.

Thanks a lot in advance and kind regards,
Clemente