DQN MultiAgentReplayBuffer not working

1. Severity of the issue: (select one)
High: Completely blocks me.

2. Environment:

  • Ray version: 2.53.0
  • Python version: 3.12
  • OS: Windows 11

3. What happened vs. what you expected:

The MultiAgentReplayBuffer does not work and appears to have been broken for quite some time. I found other posts here reporting the same issue, but none of them arrived at a solution.

        config = (
            DQNConfig()
            .framework("torch")
            .environment("sumo_marl", env_config=env_config)
            .env_runners(
                num_env_runners=1 if not DEBUG else 0,
                num_envs_per_env_runner=1,
                num_cpus_per_env_runner=3,
                sample_timeout_s=50000,
            )
            .multi_agent(
                policies=["shared"],
                policy_mapping_fn=lambda agent_id, *a, **kw: "shared",
            )
            .learners(
                num_learners=0,
                num_cpus_per_learner=3,
            )
            .training(
                gamma=0.99,
                lr=1e-4,

                # Important: this is the warm-up parameter name in this Ray version
                num_steps_sampled_before_learning_starts=20_000,

                # The batch size is passed in via kwargs here, as with MAPPO
                train_batch_size=4096,

                replay_buffer_config={
                    "_enable_replay_buffer_api": True,
                    "type": "MultiAgentReplayBuffer",
                    "capacity": 300_000,
                },
                target_network_update_freq=8000,

                double_q=True,
                dueling=True,
                n_step=1,

                epsilon=EPS_SCHEDULE,
            )
            .callbacks(MetricsLoggerCallback)

            .debugging(seed=SEED)
        )

This config results in this exception:

2026-01-14 21:48:29,784	ERROR tune_controller.py:1331 -- Trial task failed for trial DQN_sumo_marl_54156_00000
Traceback (most recent call last):
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\air\execution\_internal\event_manager.py", line 110, in resolve_future
    result = ray.get(future)
             ^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\_private\auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\_private\client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\_private\worker.py", line 2967, in get
    values, debugger_breakpoint = worker.get_objects(
                                  ^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\_private\worker.py", line 1015, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::DQN.train() (pid=106336, ip=127.0.0.1, actor_id=90a818058acadd367bed5c9401000000, repr=DQN(env=sumo_marl; env-runners=1; learners=0; multi-agent=True))
  File "python\\ray\\_raylet.pyx", line 1722, in ray._raylet.execute_task
  File "python\\ray\\_raylet.pyx", line 1665, in ray._raylet.execute_task.function_executor
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\_private\function_manager.py", line 693, in actor_method_executor
    return method(__ray_actor, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
    return method(self, *_args, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\tune\trainable\trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\tune\trainable\trainable.py", line 328, in train
    result = self.step()
             ^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
    return method(self, *_args, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 1242, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
    return method(self, *_args, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 3666, in _run_one_training_iteration
    training_step_return_value = self.training_step()
                                 ^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
    return method(self, *_args, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\dqn\dqn.py", line 646, in training_step
    return self._training_step_new_api_stack()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
    return method(self, *_args, **_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\dqn\dqn.py", line 668, in _training_step_new_api_stack
    self.local_replay_buffer.add(episodes)
  File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\utils\replay_buffers\multi_agent_replay_buffer.py", line 224, in add
    batch = batch.as_multi_agent()
            ^^^^^^^^^^^^^^^^^^^^
AttributeError: 'list' object has no attribute 'as_multi_agent'

This is a known issue: with DQN on the new API stack, the MultiAgentReplayBuffer's add() method expects a SampleBatch, MultiAgentBatch, or dict, but the new training step passes it a list of MultiAgentEpisode objects, which has no .as_multi_agent() method, hence the AttributeError. The problem is confirmed in recent user reports, has persisted across several Ray versions, and is not resolved in Ray 2.53.0. There is currently no official workaround in the documentation. Switching to the old API stack or to a different replay buffer type may avoid the error, but the new API stack with MultiAgentReplayBuffer does not currently work for multi-agent DQN training.
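If you want to stay on the new API stack, one possible workaround is to switch to an episode-based buffer, whose add() method accepts the list of episodes that the new DQN training step produces. The snippet below is an untested sketch; the buffer class name is an assumption based on the episode buffers shipped with recent Ray releases, so verify that it exists in your installation before relying on it:

replay_buffer_config={
    # Episode-based multi-agent buffer (assumed name; check your Ray version).
    # Its add() accepts lists of MultiAgentEpisode objects instead of batches.
    "type": "MultiAgentPrioritizedEpisodeReplayBuffer",
    "capacity": 300_000,
},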

Would you like more detail on possible workarounds or alternative configurations?


@RunLLM How would the config above need to be changed to work with the old API stack?

To use the old API stack, set .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False) in your config. Also, for multi-agent DQN with a replay buffer, use "type": "MultiAgentPrioritizedReplayBuffer" in replay_buffer_config (not "MultiAgentReplayBuffer"), as the old stack does not support the episode-based buffers. Example:

config = (
    DQNConfig()
    .framework("torch")
    .environment("sumo_marl", env_config=env_config)
    .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)
    .env_runners(num_env_runners=1, num_envs_per_env_runner=1, num_cpus_per_env_runner=3, sample_timeout_s=50000)
    .multi_agent(
        policies=["shared"],
        policy_mapping_fn=lambda agent_id, *a, **kw: "shared",
    )
    .learners(num_learners=0, num_cpus_per_learner=3)
    .training(
        gamma=0.99,
        lr=1e-4,
        num_steps_sampled_before_learning_starts=20_000,
        train_batch_size=4096,
        replay_buffer_config={
            "type": "MultiAgentPrioritizedReplayBuffer",
            "capacity": 300_000,
        },
        target_network_update_freq=8000,
        double_q=True,
        dueling=True,
        n_step=1,
        epsilon=EPS_SCHEDULE,
    )
    .callbacks(MetricsLoggerCallback)
    .debugging(seed=SEED)
)

This avoids the new stack’s episode buffer and uses the supported prioritized buffer for multi-agent DQN.
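To smoke-test the old-stack config before handing it to Tune, you could build the algorithm once and run a single training iteration. This is a quick sanity-check sketch and assumes the environment has already been registered with register_env:

algo = config.build()   # newer Ray versions may prefer config.build_algo()
result = algo.train()   # one training iteration
print(sorted(result.keys()))
algo.stop()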

Would you like a step-by-step explanation or more details?


@RunLLM

This code fails to run:

import random
from pathlib import Path

import numpy as np
import ray
import torch
from ray import tune
from ray.air.integrations.wandb import WandbLoggerCallback
from ray.rllib.algorithms import DQNConfig
from ray.rllib.core.rl_module import MultiRLModuleSpec, RLModuleSpec
from ray.rllib.examples.algorithms.mappo.mappo import MAPPOConfig
from ray.rllib.examples.algorithms.mappo.torch.shared_critic_torch_rl_module import SharedCriticTorchRLModule
from ray.tune.registry import register_env

from callbacks import MetricsLoggerCallback
from config.config import GENERATE_RANDOM_ROUTES, DEBUG, LOG_TO_WANDB
from rl_environment.observation_classes import CameraObservation, NoisyCameraObservation, DefaultObservationClass
from rl_environment.reward_classes import NegativeWaitAndStopGo
from rl_environment.sumo_traffic_env import SumoTrafficEnv

SHARED_CRITIC_ID = "shared_critic"
SEED = 100

TRAINING_CONFIGURATIONS = [
    {
        "RUN_NAME": "MAPPO DOC n_w_a_s_g norm",
        "OBSERVATION_CLASS": DefaultObservationClass,
        "REWARD_FUNCTION": NegativeWaitAndStopGo
    },
]


def env_creator(env_config):
    current_file = Path(__file__)
    project_base = current_file.parent.parent

    # Define the file paths and the maximum simulation time
    net: Path = project_base / "simulation_files" / "net.net.xml"
    route: Path = project_base / "simulation_files" / "random.rou.xml"
    trip: Path = project_base / "simulation_files" / "random.trips.xml"
    additional: Path = project_base / "simulation_files" / "mytypes.add.xml"

    return SumoTrafficEnv(
        sumo_net_file=net,
        sumo_route_file=route,
        sumo_trip_file=trip,
        sumo_additional_file=additional,
        reward_class=env_config["REWARD_FUNCTION"],
        observation_class=env_config["OBSERVATION_CLASS"],
        show_gui=False,
        simulation_time=600,
        generate_random_routes=GENERATE_RANDOM_ROUTES,
        sumo_simulation_seed=str(SEED)
    )


if __name__ == "__main__":
    for training_configuration in TRAINING_CONFIGURATIONS:
        run_name = training_configuration["RUN_NAME"]
        observation_class = training_configuration["OBSERVATION_CLASS"]
        reward_function = training_configuration["REWARD_FUNCTION"]

        env_config = {
            "OBSERVATION_CLASS": observation_class,
            "REWARD_FUNCTION": reward_function,
        }

        current_file = Path(__file__)
        project_base = current_file.parent.parent

        initial_data_env = env_creator(env_config)

        agent_ids = initial_data_env.agents

        ray.init(local_mode=DEBUG)

        register_env("sumo_marl", env_creator)

        observation_space = initial_data_env.observation_spaces[agent_ids[0]]
        action_space = initial_data_env.action_spaces[agent_ids[0]]

        EPS_SCHEDULE = [
            [0, 1.0],
            [200_000, 0.02],
        ]

        config = (
            DQNConfig()
            .framework("torch")
            .environment("sumo_marl", env_config=env_config)
            .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)
            .env_runners(num_env_runners=1, num_envs_per_env_runner=1, num_cpus_per_env_runner=3,
                         sample_timeout_s=50000)
            .multi_agent(
                policies=["shared"],
                policy_mapping_fn=lambda agent_id, *a, **kw: "shared",
            )
            .learners(num_learners=0, num_cpus_per_learner=3)
            .training(
                gamma=0.99,
                lr=1e-4,
                num_steps_sampled_before_learning_starts=20_000,
                train_batch_size=4096,
                replay_buffer_config={
                    "type": "MultiAgentPrioritizedReplayBuffer",
                    "capacity": 300_000,
                },
                target_network_update_freq=8000,
                double_q=True,
                dueling=True,
                n_step=1,
                epsilon=EPS_SCHEDULE,
            )
            .callbacks(MetricsLoggerCallback)
            .debugging(seed=SEED)
        )


        param_space = config.to_dict()
        # param_space["seed"] = SEED

        # Start training
        tuner = tune.Tuner(
            config.algo_class,
            param_space=param_space,
            run_config=tune.RunConfig(
                name=run_name,
                stop={"training_iteration": 250},
                verbose=1,
                storage_path=str(project_base / "run_results"),
                callbacks=[WandbLoggerCallback(project="new_env", name=run_name, log_config=True)] if LOG_TO_WANDB else None,
                checkpoint_config=tune.CheckpointConfig(
                    checkpoint_frequency=10, checkpoint_at_end=True
                ),
            ),
        )

        tuner.fit()

        ray.shutdown()


It raises this exception:

(DQN pid=116768) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::DQN.__init__() (pid=116768, ip=127.0.0.1, actor_id=84c3c0bc673b9c96e973b5e101000000, repr=DQN)
(DQN pid=116768)   File "python\\ray\\_raylet.pyx", line 1722, in ray._raylet.execute_task
(DQN pid=116768)   File "python\\ray\\_raylet.pyx", line 1665, in ray._raylet.execute_task.function_executor
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\_private\function_manager.py", line 693, in actor_method_executor
(DQN pid=116768)     return method(__ray_actor, *args, **kwargs)
(DQN pid=116768)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
(DQN pid=116768)     return method(self, *_args, **_kwargs)
(DQN pid=116768)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 579, in __init__
(DQN pid=116768)     super().__init__(
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\tune\trainable\trainable.py", line 158, in __init__
(DQN pid=116768)     self.setup(copy.deepcopy(self.config))
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
(DQN pid=116768)     return method(self, *_args, **_kwargs)
(DQN pid=116768)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\dqn\dqn.py", line 613, in setup
(DQN pid=116768)     super().setup(config)
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 777, in setup
(DQN pid=116768)     self.local_replay_buffer = self._create_local_replay_buffer_if_necessary(
(DQN pid=116768)                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\util\tracing\tracing_helper.py", line 461, in _resume_span
(DQN pid=116768)     return method(self, *_args, **_kwargs)
(DQN pid=116768)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(DQN pid=116768)   File "E:\Users\julia\Desktop\Studium\Master\Masterarbeit\TrafficSignalMARL\venv\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 4334, in _create_local_replay_buffer_if_necessary
(DQN pid=116768)     if "EpisodeReplayBuffer" in config["replay_buffer_config"]["type"]:
(DQN pid=116768)        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(DQN pid=116768) TypeError: argument of type 'ABCMeta' is not iterable

Why?

The error occurs because the value of "type" in your replay_buffer_config is not a string, but a class/type object. The old API stack expects "type" to be a string (e.g., "MultiAgentPrioritizedReplayBuffer"), not a class reference. When the code checks "EpisodeReplayBuffer" in config["replay_buffer_config"]["type"], it fails if "type" is not a string, resulting in TypeError: argument of type 'ABCMeta' is not iterable.

To fix this, ensure that "type" is a string, not a class. For example:

replay_buffer_config={
    "type": "MultiAgentPrioritizedReplayBuffer",  # <-- must be a string
    "capacity": 300_000,
}

This matches the expected usage in the old API stack and avoids the TypeError.

Would you like a step-by-step breakdown or more troubleshooting tips?


@RunLLM In the code I provided, "type" is already a string.

You are correct—your code sets "type": "MultiAgentPrioritizedReplayBuffer" as a string. However, the error occurs because somewhere in the config processing, the "type" field is being converted from a string to a class object (likely by RLlib’s internal validate_buffer_config or similar logic). This means that by the time the check if "EpisodeReplayBuffer" in config["replay_buffer_config"]["type"]: runs, "type" is no longer a string, but a class, causing the TypeError.

This is a known RLlib bug: after config validation, "type" is replaced with the class, but later code still assumes it is a string and applies the in operator to it. The issue is discussed in the Ray community and is caused by RLlib internals, not by your code. There is no user-side fix other than downgrading to a version where the bug does not occur, or patching RLlib locally so that the check handles both cases (e.g., if isinstance(type_val, str) and "EpisodeReplayBuffer" in type_val:).
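For illustration, a patched check could normalize the value to a class name before the substring test. The helper below is a hypothetical sketch of such a local patch to _create_local_replay_buffer_if_necessary in ray/rllib/algorithms/algorithm.py, not an official fix:

def is_episode_buffer(buffer_type) -> bool:
    # After RLlib's config validation, "type" may be a class instead of a string,
    # so normalize it to the class name before the substring check.
    name = buffer_type if isinstance(buffer_type, str) else getattr(buffer_type, "__name__", "")
    return "EpisodeReplayBuffer" in name

The failing line would then call is_episode_buffer(config["replay_buffer_config"]["type"]) instead of applying in directly to the raw value.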

Would you like more detail on possible workarounds or a code patch?
