Help with PPO config in a multi-agent env with complex observations

Hi everyone!

I’m experimenting with the new RLlib API stack (enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True) and trying to train a multi-agent PPO setup. I use a custom environment with two types of agents: "pass" and "gruz". Here’s a minimal example of my config:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.tune.registry import register_env
from env.dispatch_env_cubes_ray import DispatchEnvCubesRay

register_env('cubes_ray', lambda _: DispatchEnvCubesRay(render_mode='silent'))

config = (
    PPOConfig()
    .environment("cubes_ray")
    .multi_agent(
        policies={"pass", "gruz"},
        # Map any agent ID starting with "pass" to the "pass" policy, everything else to "gruz".
        policy_mapping_fn=lambda aid, *args, **kwargs: "pass" if aid.startswith("pass") else "gruz",
        policies_to_train=["pass", "gruz"],
    )
    .rl_module(
        rl_module_spec=MultiRLModuleSpec(
            rl_module_specs={
                "pass": RLModuleSpec(),
                "gruz": RLModuleSpec()
            }
        )
    )
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .env_runners(
        # Flatten each agent's complex Box observations before they reach the RLModules.
        env_to_module_connector=lambda env: FlattenObservations(multi_agent=True),
        num_envs_per_env_runner=1,
    )
    .training(lr=0.0003)
    .resources(num_gpus=1)
)

algo = config.build_algo()

for i in range(5):
    try:
        result = algo.train()
        print(f"Iteration {i+1} completed successfully")
    except Exception as e:
        print(f"Error during iteration {i+1}: {str(e)}")
        break

Console Output

2025-04-11 08:49:27,739 WARNING algorithm_config.py:4704 -- You are running PPO on the new API stack! This is the new default behavior for this algorithm. If you don't want to use the new API stack, set `config.api_stack(enable_rl_module_and_learner=False,enable_env_runner_and_connector_v2=False)`. For a detailed migration guide, see here: https://docs.ray.io/en/master/rllib/new-api-stack-migration-guide.html
/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/rllib/algorithms/algorithm.py:512: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
/virtual-dispatcher-environment-model/.venv/lib/python3.12/site-packages/ray/tune/logger/unified.py:53: RayDeprecationWarning: This API is deprecated and may be removed in future Ray releases. You could suppress this warning by setting env variable PYTHONWARNINGS="ignore::DeprecationWarning"
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2025-04-11 08:49:30,725 WARNING services.py:2070 -- WARNING: The object store is using /tmp instead of /dev/shm because /dev/shm has only 67096576 bytes available. This will harm performance! You may be able to free up space by deleting files in /dev/shm. If you are inside a Docker container, you can increase /dev/shm size by passing '--shm-size=10.24gb' to 'docker run' (or add it to the run_options list in a Ray cluster config). Make sure to set this to more than 30% of available RAM.
2025-04-11 08:49:31,929 INFO worker.py:1852 -- Started a local Ray instance.
(MultiAgentEnvRunner pid=2588717) 2025-04-11 08:49:39,671       WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
2025-04-11 08:49:39,906 WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!
2025-04-11 08:49:40,850 INFO trainable.py:160 -- Trainable.setup took 13.099 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.
2025-04-11 08:49:40,850 WARNING util.py:61 -- Install gputil for GPU system monitoring.
Iteration 1 completed successfully
Error during iteration 2: 'learner'
(MultiAgentEnvRunner pid=2588716) 2025-04-11 08:49:39,750       WARNING deprecation.py:50 -- DeprecationWarning: `RLModule(config=[RLModuleConfig object])` has been deprecated. Use `RLModule(observation_space=.., action_space=.., inference_only=.., model_config=.., catalog_class=..)` instead. This will raise an error in the future!

As you can see, the first iteration completes successfully, but the second one crashes:

Iteration 1 completed successfully  
Error during iteration 2: 'learner'

The full logs also include warnings like:

DeprecationWarning: `RLModule(config=...)` has been deprecated.

The bare 'learner' message looks like the string form of a KeyError, so I’m guessing something fails during the actual learner update step (maybe gradient computation).
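
For what it’s worth, str(e) only gives me that bare 'learner' string, so I could rerun the loop printing the full traceback to see exactly which lookup fails (nothing RLlib-specific here, just the standard traceback module):

import traceback

for i in range(5):
    try:
        result = algo.train()
        print(f"Iteration {i+1} completed successfully")
    except Exception:
        # Print the whole stack trace instead of just str(e),
        # so it's clearer where inside the learner update the failure happens.
        traceback.print_exc()
        break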

I suspect this may be related to issue #51333. My environment has fairly complex multi-agent observations (spaces.Box(0, 2, (4, 2, 10), int)), so I use FlattenObservations(multi_agent=True) to flatten them. However, FlattenObservations(multi_agent=True) actually returns a Dict keyed by agent ID, and I wonder if that breaks something internally during the learner step.
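
To make the observation setup concrete, here is roughly what I believe the data looks like before and after the connector. The agent IDs "pass_0" / "gruz_0" are just illustrative, and the flattened size of 80 (= 4 * 2 * 10) is my assumption about what FlattenObservations produces for a Box space:

from gymnasium import spaces
import numpy as np

# Per-agent observation space in my env (simplified; both agent types use the same shape):
per_agent_obs_space = spaces.Box(0, 2, (4, 2, 10), int)

# What the env step returns: a dict keyed by agent ID (IDs here are illustrative).
raw_obs = {
    "pass_0": per_agent_obs_space.sample(),  # shape (4, 2, 10)
    "gruz_0": per_agent_obs_space.sample(),
}

# What I believe FlattenObservations(multi_agent=True) hands to the modules:
# still a dict keyed by agent ID, but with each observation flattened to 1D.
flat_obs = {aid: np.asarray(o).reshape(-1) for aid, o in raw_obs.items()}  # each shape (80,)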

Questions

  1. Do I need to flatten my observations twice, once per agent and then again at the top level?
  2. Should I instead apply FlattenObservations(multi_agent=False) inside each agent’s RLModule? If so, how?
  3. Is my usage of MultiRLModuleSpec with bare RLModuleSpec() entries sufficient, or am I missing required fields like observation_space and action_space? (There’s a sketch of what I mean right below.)
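
For question 3, this is the kind of fully specified spec I had in mind. The action space below is just a placeholder (not my real one), and I’m not sure whether observation_space should be the raw per-agent space or the already-flattened one:

from gymnasium import spaces
from ray.rllib.core.rl_module.rl_module import RLModuleSpec
from ray.rllib.core.rl_module.multi_rl_module import MultiRLModuleSpec

obs_space = spaces.Box(0, 2, (4, 2, 10), int)  # raw per-agent space (or the flattened one?)
act_space = spaces.Discrete(5)                 # placeholder, not my real action space

rl_module_spec = MultiRLModuleSpec(
    rl_module_specs={
        "pass": RLModuleSpec(observation_space=obs_space, action_space=act_space),
        "gruz": RLModuleSpec(observation_space=obs_space, action_space=act_space),
    }
)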

Let me know if a full code sample would help — I’d be happy to share more.

Thanks in advance!