1. Severity of the issue:
High: Completely blocks me.
2. Environment:
- Ray version: 2.47.1
- Python version: 3.9.23
- OS: Linux
3. What happened vs. what you expected:
- Expected: Multi-agent DQN training runs
- Actual: Multi-agent DQN training fails with an AttributeError
Hello,
I believe I have run into a bug in the DQN algorithm when `MultiAgentReplayBuffer` is used with the new API stack. This post describes the exact same problem I am seeing, but after some digging into the code I would like to provide further information and hope someone can give me more guidance.
The problem seems to be in `ray/rllib/algorithms/dqn/dqn.py`. On line 654, it creates the `episodes` variable through `synchronous_parallel_sample()`, which, in our case, is a list of Multi-Agent Episodes (MAEps). In the following lines, it calls `self.local_replay_buffer.add(episodes)`. Here, the `add()` method expects its input to be of type `SampleBatchType`, i.e. a union of `["SampleBatch", "MultiAgentBatch", Dict[str, Any]]`, and it calls `.as_multi_agent()` on that input. This is where the error arises: a list of MAEps is not one of the accepted types and does not have an `.as_multi_agent()` method.
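To illustrate the type mismatch outside of RLlib, here is a minimal, self-contained sketch. The names `FakeBatch` and `add` are made up for illustration; only the failure mode mirrors the unconditional `.as_multi_agent()` call in `MultiAgentReplayBuffer.add()`:

```python
from typing import Dict


class FakeBatch:
    """Stand-in for a SampleBatch-like object (illustrative only)."""

    def as_multi_agent(self) -> Dict[str, "FakeBatch"]:
        # A real SampleBatch wraps itself under a default policy ID.
        return {"default_policy": self}


def add(batch):
    # Mirrors the failing line in add(): .as_multi_agent() is called
    # unconditionally, so any input without that method raises.
    return batch.as_multi_agent()


print(add(FakeBatch()))  # works: a dict keyed by policy ID
try:
    add([FakeBatch()])  # a plain list has no .as_multi_agent()
except AttributeError as err:
    print(err)  # 'list' object has no attribute 'as_multi_agent'
```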
However, looking at the documentation in the `synchronous_parallel_sample` source code, it does seem like the old API returns the correct type, a `SampleBatch` object. The list of MAEps is only returned when `_uses_new_env_runners` is set to `True`; otherwise we get a list of `SampleBatch` objects instead. However, I cannot successfully switch to the old API via the config. Is there some example code I could follow? If not, is there any other quick fix to this problem that I could use?
The error I got is:
File "/home/yw180/sequential_social_dilemma_games/run_scripts/DQN_test.py", line 68, in <module>
    result = algo.train()
File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 330, in train
    raise skipped from exception_cause(skipped)
File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 327, in train
    result = self.step()
File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 1035, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 3309, in _run_one_training_iteration
    training_step_return_value = self.training_step()
File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 644, in training_step
    return self._training_step_new_api_stack()
File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/algorithms/dqn/dqn.py", line 669, in _training_step_new_api_stack
    self.local_replay_buffer.add(episodes)
File "/home/yw180/.conda/envs/ssd_marl/lib/python3.9/site-packages/ray/rllib/utils/replay_buffers/multi_agent_replay_buffer.py", line 224, in add
    batch = batch.as_multi_agent()
AttributeError: 'list' object has no attribute 'as_multi_agent'
Here is a simple multi-agent cartpole example using DQN that would reproduce this exact error:
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentCartPole
from ray.rllib.utils.test_utils import (
    add_rllib_example_script_args,
)
from ray.tune.registry import register_env
from ray.rllib.algorithms.dqn import DQNConfig

parser = add_rllib_example_script_args(
    default_iters=200,
    default_timesteps=100000,
    default_reward=600.0,
)
# TODO (sven): This arg is currently ignored (hard-set to 2).
parser.add_argument("--num-policies", type=int, default=2)

if __name__ == "__main__":
    args = parser.parse_args()

    # Register our environment with tune.
    register_env(
        "env",
        lambda _: MultiAgentCartPole(config={"num_agents": 2}),
    )

    config = (
        DQNConfig()
        .environment("env")
        .env_runners(
            num_env_runners=1,
            explore=True,
        )
        .framework("torch")
        .multi_agent(
            policies={f"p{i}" for i in range(2)},
            policy_mapping_fn=lambda aid, *a, **kw: f"p{aid}",
        )
        .training(
            replay_buffer_config={
                "type": "MultiAgentReplayBuffer",
            },
            lr=0.001,
            train_batch_size=100,
            gamma=0.99,
        )
    )

    algo = config.build_algo()
    for i in range(10):
        result = algo.train()
Thanks!