Multi-Agent with Centralized Critic using an Attention Model

I am running multi-agent experiments and want to build a PPO setup with a centralized critic. I am using the rllib/examples/centralized_critic_2.py example as a reference.

My goal is to give both the actor and the critic the same architecture they would get from ModelCatalog.get_model_v2 when passing the standard model config. In particular, I want to set use_attention=True so that both networks are wrapped in RLlib's attention (GTrXL) wrapper.
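To make the intent concrete, this is roughly the call I have in mind for a single flat observation space (the spaces and num_outputs below are placeholders for illustration, not my actual values):

import numpy as np
from gymnasium import spaces
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS

# Placeholder spaces, only to illustrate the call.
flat_obs_space = spaces.Box(-np.inf, np.inf, shape=(12,), dtype=np.float32)
act_space = spaces.Box(-1.0, 1.0, shape=(2,), dtype=np.float32)

model_config = MODEL_DEFAULTS.copy()
model_config.update({
    "use_attention": True,      # wrap the default net in RLlib's GTrXL attention wrapper
    "attention_num_heads": 8,
    "max_seq_len": 8,
})

model = ModelCatalog.get_model_v2(
    obs_space=flat_obs_space,
    action_space=act_space,
    num_outputs=2 * act_space.shape[0],  # e.g. mean + log-std for a diag-Gaussian policy
    model_config=model_config,
    framework="torch",
)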

Below is the code for the model with the centralized critic:

import numpy as np
import torch
import torch.nn as nn
from gymnasium import spaces
from ray.rllib.models.torch.attention_net import AttentionWrapper, GTrXLNet
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS


class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Multi-agent model that implements a centralized value function.

    It assumes the observation is a dict with 'own_obs' and 'opponent_obs', the
    former of which can be used for computing actions (i.e., decentralized
    execution), and the latter for optimization (i.e., centralized learning).

    This model has two parts:
    - An action model that looks at just 'own_obs' to compute actions
    - A value model that also looks at the 'opponent_obs' / 'opponent_action'
      to compute the value (it does this by using the 'obs_flat' tensor).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        config = MODEL_DEFAULTS.copy()
        config.update(model_config)
        model_config = config.copy()
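        # Remove "custom_model" so get_model_v2 builds RLlib's default (attention-wrapped)
        # network instead of recursing back into this class.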
        model_config.pop("custom_model")

        obs_space = spaces.Dict()
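        # Hard-coded per-agent sub-spaces: the actor sees an (8, 12) observation,
        # the centralized critic an (8, 45) one.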

        obs_space["actor_obs"] = spaces.Box(low=-np.inf, high=np.inf,
                                            shape=(8, 12,),
                                            dtype=np.float64)

        obs_space["critic_obs"] = spaces.Box(low=-np.inf, high=np.inf,
                                             shape=(8, 45,),
                                             dtype=np.float64)

        self.action_model = ModelCatalog.get_model_v2(obs_space=obs_space["actor_obs"],
                                                      action_space=action_space,
                                                      num_outputs=1,
                                                      model_config=model_config,
                                                      framework="torch",)

        self.value_model = ModelCatalog.get_model_v2(obs_space=obs_space["critic_obs"],
                                                     action_space=action_space,
                                                     num_outputs=1,
                                                     model_config=model_config,
                                                     framework="torch",)

        self._model_in = None

    def forward(self, input_dict, state, seq_lens):
        # Store model-input for possible `value_function()` call.
        seq_lens = torch.ones(len(input_dict))
        input_dict_temp = input_dict.copy()
        input_dict_temp.pop("obs_flat")

        input_dict_temp["obs"] = input_dict["obs"]["critic_obs"]
        self._model_in = [input_dict_temp, state, seq_lens]

        input_dict_temp["obs"] = input_dict["obs"]["actor_obs"]
        return self.action_model(input_dict_temp, state, seq_lens)

    def value_function(self):
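        # Feed the cached critic input through the value model; only its value head output is used.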
        _, _ = self.value_model(
            self._model_in[0], self._model_in[1], self._model_in[2]
        )

        value_out = self.value_model.value_function()
        return value_out
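
For completeness, the custom model is registered under the name "cc_model" before the algorithm is built, essentially like this:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model(
    "cc_model", YetAnotherTorchCentralizedCriticModel
)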

And here is how I build the PPO algorithm:

    algo = (
        PPOConfig()
        .environment(env_id, env_config=config["environment"])
        .resources(num_gpus=config["agent"]["num_gpus"], )
        .rollouts(num_rollout_workers=1)
        # .training(**config["training"])
        .training(model={"custom_model": "cc_model",
                         "max_seq_len": 8,
                         "_disable_preprocessor_api": False,
                         "use_attention": True,
                         "attention_num_heads": 8,
                         })
        .experimental(_enable_new_api_stack=False)
        .multi_agent(
            policies={"shared_policy"},
            policy_mapping_fn=(lambda agent_id, episode, worker, **kwargs: "shared_policy"),
            # observation_fn=central_critic_observer,
            # policies={f'household_{i}' for i in range(10)},
            # policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )
        .build(logger_creator=custom_log_creator(root_dir, config['agent']['algorithm']))
    )
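
For context, each agent's observation space is a Dict matching the sub-spaces hard-coded in the model:

import numpy as np
from gymnasium import spaces

# Per-agent observation space; RLlib's preprocessor flattens it to
# 8 * 12 + 8 * 45 = 456 features (the number that appears in the error below).
per_agent_obs_space = spaces.Dict({
    "actor_obs": spaces.Box(-np.inf, np.inf, shape=(8, 12), dtype=np.float64),
    "critic_obs": spaces.Box(-np.inf, np.inf, shape=(8, 45), dtype=np.float64),
})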

I get the following error:

Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevconsole.py", line 364, in runcode
    coro = func()
  File "<input>", line 1, in <module>
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/Users/matteotortora/Matteo Tortora/Universita/Progetti/DRL - Energy Community/example.py", line 286, in <module>
    algo = run_training(env_id, config, root_dir)
  File "/Users/matteotortora/Matteo Tortora/Universita/Progetti/DRL - Energy Community/example.py", line 242, in run_training
    PPOConfig()
  File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm_config.py", line 1137, in build
    return algo_class(
  File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 516, in __init__
    super().__init__(
  File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
    self.setup(copy.deepcopy(self.config))
  File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 638, in setup
    self.workers = WorkerSet(
  File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 181, in __init__
    raise e.args[0].args[2]
ValueError: Expected flattened obs shape of [..., 456], got torch.Size([32, 1])
(RolloutWorker pid=90863) 2024-01-18 13:00:20,585	WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.models.torch.attention_net.AttentionWrapper` has been deprecated. This will raise an error in the future!
(RolloutWorker pid=90863) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=90863, ip=127.0.0.1, actor_id=cfe8f4d9a45316a6689662b901000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x16a79bbe0>)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 535, in __init__
(RolloutWorker pid=90863)     self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1746, in _update_policy_map
(RolloutWorker pid=90863)     self._build_policy_map(
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1857, in _build_policy_map
(RolloutWorker pid=90863)     new_policy = create_policy_for_framework(
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 141, in create_policy_for_framework
(RolloutWorker pid=90863)     return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
(RolloutWorker pid=90863)     self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 1430, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=90863)     actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/policy/torch_policy_v2.py", line 572, in compute_actions_from_input_dict
(RolloutWorker pid=90863)     return self._compute_action_helper(
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
(RolloutWorker pid=90863)     return func(self, *a, **k)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1293, in _compute_action_helper
(RolloutWorker pid=90863)     dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 263, in __call__
(RolloutWorker pid=90863)     res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/torch/attention_net.py", line 442, in forward
(RolloutWorker pid=90863)     self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 251, in __call__
(RolloutWorker pid=90863)     restored["obs"] = restore_original_dimensions(
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 417, in restore_original_dimensions
(RolloutWorker pid=90863)     return _unpack_obs(obs, original_space, tensorlib=tensorlib)
(RolloutWorker pid=90863)   File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 451, in _unpack_obs
(RolloutWorker pid=90863)     raise ValueError(
(RolloutWorker pid=90863) ValueError: Expected flattened obs shape of [..., 456], got torch.Size([32, 1])

What am I doing wrong? Any tips?