Multi-agent Replay Buffer in DQN fails to run

1. Severity of the issue:
High: Completely blocks me.

2. Environment:

  • Ray version: 2.47.1
  • Python version: 3.9.23
  • OS: Linux

3. What happened vs. what you expected:

  • Expected: Multi-agent DQN training runs.
  • Actual: Multi-agent DQN training crashes with AttributeError: 'list' object has no attribute 'as_multi_agent'.

Hello,

I believe I have hit a bug in the DQN algorithm when MultiAgentReplayBuffer is used with the new API stack. This post describes the exact same problem as mine, but after some digging into the code I would like to provide some further information, and I hope someone can give me more guidance.

The problem seems to be in ray/rllib/algorithms/dqn/dqn.py. Around line 654, the episodes variable is created through synchronous_parallel_sample(), which, in our case, is a list of MultiAgentEpisode objects (MAEps). A few lines later, self.local_replay_buffer.add(episodes) is called. The add() method expects its input to be a SampleBatchType, i.e. Union["SampleBatch", "MultiAgentBatch", Dict[str, Any]], and calls .as_multi_agent() on it, which is where the error occurs.

The problem arises because a list of MAEps is not one of the accepted types and has no .as_multi_agent() method.
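
To make the type mismatch concrete, here is a minimal, Ray-free sketch. The SampleBatch and MultiAgentEpisode classes below are hypothetical stand-ins for the real RLlib classes, just to show why the call fails:

```python
class SampleBatch:
    """Stand-in for rllib's SampleBatch (old API stack)."""

    def as_multi_agent(self):
        # The real method wraps the batch in a MultiAgentBatch; no-op here.
        return self


class MultiAgentEpisode:
    """Stand-in for a new-stack episode; it has no as_multi_agent() method."""


def buffer_add(batch):
    # Mirrors multi_agent_replay_buffer.py line 224: assumes a SampleBatchType.
    return batch.as_multi_agent()


# Old stack: a single SampleBatch works.
buffer_add(SampleBatch())

# New stack: a *list* of episodes is passed, and the attribute lookup fails.
episodes = [MultiAgentEpisode(), MultiAgentEpisode()]
try:
    buffer_add(episodes)
except AttributeError as e:
    print(e)  # 'list' object has no attribute 'as_multi_agent'
```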

However, looking at the documentation in the synchronous_parallel_sample source code, it does seem like the old API stack returns the correct type, namely SampleBatch objects. The list of MAEps is only returned when _uses_new_env_runners is set to True; otherwise we get a list of SampleBatch objects instead. However, I cannot successfully switch to the old API stack via the config, and was wondering if there is some example code I could follow. If not, are there any other quick fixes I could use?
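
For reference, this is what I tried to switch stacks, based on the AlgorithmConfig.api_stack() signature; I have not gotten it to work on 2.47.1, so treat this as a sketch rather than a working recipe:

```python
from ray.rllib.algorithms.dqn import DQNConfig

# Attempt to opt out of the new API stack by disabling both new-stack flags.
config = (
    DQNConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
)
```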

The error I got is:

  File "/home/yw180/sequential_social_dilemma_games/run_scripts/DQN_test.py", line 68, in <module>
    result = algo.train()
  File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 330, in train
    raise skipped from exception_cause(skipped)
  File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 327, in train
    result = self.step()
  File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 1035, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
  File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 3309, in _run_one_training_iteration
    training_step_return_value = self.training_step()
  File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 644, in training_step
    return self._training_step_new_api_stack()
  File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 669, in _training_step_new_api_stack
    self.local_replay_buffer.add(episodes)
  File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/utils/replay_buffers/multi_agent_replay_buffer.py", line 224, in add
    batch = batch.as_multi_agent()
AttributeError: 'list' object has no attribute 'as_multi_agent'

Here is a simple multi-agent CartPole example using DQN that reproduces this exact error:


from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
)
from ray.tune.registry import register_env
from ray.rllib.algorithms.dqn import DQNConfig
parser = add_rllib_example_script_args(
    default_iters=200,
    default_timesteps=100000,
    default_reward=600.0,
)
# TODO (sven): This arg is currently ignored (hard-set to 2).
parser.add_argument("--num-policies", type=int, default=2)


if __name__ == "__main__":
    args = parser.parse_args()

    # Register our environment with tune.
    register_env(
        "env",
        lambda _: MultiAgentCartPole(config={"num_agents": 2}),
    )
    config = (
        DQNConfig()
        .environment("env")
        .framework("torch")
        .env_runners(
            num_env_runners=1,
            explore=True,
        )
        .multi_agent(
            policies={f"p{i}" for i in range(2)},
            policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
        )
        .training(
            replay_buffer_config={
                "type": "MultiAgentReplayBuffer",
            },
            lr=0.001,
            train_batch_size=100,
            gamma=0.99,
        )
    )

    algo = config.build_algo()
    for i in range(10):
        result = algo.train()

Thanks!