I am running multi-agent experiments and want to build a variant of my setup that uses a centralized critic, following the rllib/examples/centralized_critic_2.py example.
My goal is to build the actor and the critic with the same architecture that ModelCatalog.get_model_v2 would produce when given the standard arguments. In particular, I want to use 'use_attention': True.
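For reference, this is roughly the "standard" construction I am trying to mirror for each sub-model; the spaces and output size below are placeholders, not the real ones from my environment:

import numpy as np
from gymnasium import spaces
from ray.rllib.models.catalog import MODEL_DEFAULTS, ModelCatalog

# Placeholder spaces, just to illustrate the call I have in mind.
dummy_obs_space = spaces.Box(low=-np.inf, high=np.inf, shape=(8, 12), dtype=np.float64)
dummy_action_space = spaces.Discrete(4)

model_config = MODEL_DEFAULTS.copy()
model_config.update({"use_attention": True, "attention_num_heads": 8, "max_seq_len": 8})

# With "use_attention": True the catalog should wrap the default net in its
# attention (GTrXL) wrapper.
model = ModelCatalog.get_model_v2(
    obs_space=dummy_obs_space,
    action_space=dummy_action_space,
    num_outputs=4,
    model_config=model_config,
    framework="torch",
)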
Below is the code for the model with the centralized critic:
import numpy as np
import torch
import torch.nn as nn
from gymnasium import spaces
from ray.rllib.models.torch.attention_net import AttentionWrapper, GTrXLNet
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.catalog import ModelCatalog, MODEL_DEFAULTS
class YetAnotherTorchCentralizedCriticModel(TorchModelV2, nn.Module):
    """Multi-agent model that implements a centralized value function.

    It assumes the observation is a dict with 'own_obs' and 'opponent_obs', the
    former of which can be used for computing actions (i.e., decentralized
    execution), and the latter for optimization (i.e., centralized learning).
    This model has two parts:
    - An action model that looks at just 'own_obs' to compute actions
    - A value model that also looks at the 'opponent_obs' / 'opponent_action'
      to compute the value (it does this by using the 'obs_flat' tensor).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        config = MODEL_DEFAULTS.copy()
        config.update(model_config)
        model_config = config.copy()
        model_config.pop("custom_model")

        obs_space = spaces.Dict()
        obs_space["actor_obs"] = spaces.Box(
            low=-np.inf, high=np.inf, shape=(8, 12), dtype=np.float64
        )
        obs_space["critic_obs"] = spaces.Box(
            low=-np.inf, high=np.inf, shape=(8, 45), dtype=np.float64
        )

        self.action_model = ModelCatalog.get_model_v2(
            obs_space=obs_space["actor_obs"],
            action_space=action_space,
            num_outputs=1,
            model_config=model_config,
            framework="torch",
        )
        self.value_model = ModelCatalog.get_model_v2(
            obs_space=obs_space["critic_obs"],
            action_space=action_space,
            num_outputs=1,
            model_config=model_config,
            framework="torch",
        )
        self._model_in = None

    def forward(self, input_dict, state, seq_lens):
        # Store model-input for possible `value_function()` call.
        seq_lens = torch.ones(len(input_dict))
        input_dict_temp = input_dict.copy()
        input_dict_temp.pop("obs_flat")
        input_dict_temp["obs"] = input_dict["obs"]["critic_obs"]
        self._model_in = [input_dict_temp, state, seq_lens]
        input_dict_temp["obs"] = input_dict["obs"]["actor_obs"]
        return self.action_model(input_dict_temp, state, seq_lens)

    def value_function(self):
        _, _ = self.value_model(
            self._model_in[0], self._model_in[1], self._model_in[2]
        )
        value_out = self.value_model.value_function()
        return value_out
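For completeness, I register the custom model under the "cc_model" key that the PPO config below refers to (roughly like this; the exact place in my script differs):

# ModelCatalog is the same class already imported above.
ModelCatalog.register_custom_model("cc_model", YetAnotherTorchCentralizedCriticModel)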
And here is how I instantiate the PPO algorithm:
algo = (
    PPOConfig()
    .environment(env_id, env_config=config["environment"])
    .resources(num_gpus=config["agent"]["num_gpus"])
    .rollouts(num_rollout_workers=1)
    # .training(**config["training"])
    .training(
        model={
            "custom_model": "cc_model",
            "max_seq_len": 8,
            "_disable_preprocessor_api": False,
            "use_attention": True,
            "attention_num_heads": 8,
        }
    )
    .experimental(_enable_new_api_stack=False)
    .multi_agent(
        policies={"shared_policy"},
        policy_mapping_fn=(lambda agent_id, episode, worker, **kwargs: "shared_policy"),
        # observation_fn=central_critic_observer,
        # policies={f'household_{i}' for i in range(10)},
        # policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
    )
    .build(logger_creator=custom_log_creator(root_dir, config['agent']['algorithm']))
)
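In case it is relevant, the attention-related defaults that my two explicit settings get merged into can be listed with a generic inspection snippet like this (nothing specific to my setup):

from ray.rllib.models.catalog import MODEL_DEFAULTS

attention_defaults = {
    k: v for k, v in MODEL_DEFAULTS.items()
    if k == "use_attention" or k.startswith("attention")
}
print(attention_defaults)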
I get the following error:
Traceback (most recent call last):
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/pydevconsole.py", line 364, in runcode
coro = func()
File "<input>", line 1, in <module>
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/matteotortora/Matteo Tortora/Universita/Progetti/DRL - Energy Community/example.py", line 286, in <module>
algo = run_training(env_id, config, root_dir)
File "/Users/matteotortora/Matteo Tortora/Universita/Progetti/DRL - Energy Community/example.py", line 242, in run_training
PPOConfig()
File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm_config.py", line 1137, in build
return algo_class(
File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 516, in __init__
super().__init__(
File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/tune/trainable/trainable.py", line 161, in __init__
self.setup(copy.deepcopy(self.config))
File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/algorithm.py", line 638, in setup
self.workers = WorkerSet(
File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/worker_set.py", line 181, in __init__
raise e.args[0].args[2]
ValueError: Expected flattened obs shape of [..., 456], got torch.Size([32, 1])
(RolloutWorker pid=90863) 2024-01-18 13:00:20,585 WARNING deprecation.py:50 -- DeprecationWarning: `ray.rllib.models.torch.attention_net.AttentionWrapper` has been deprecated. This will raise an error in the future!
(RolloutWorker pid=90863) Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=90863, ip=127.0.0.1, actor_id=cfe8f4d9a45316a6689662b901000000, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x16a79bbe0>)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 535, in __init__
(RolloutWorker pid=90863) self._update_policy_map(policy_dict=self.policy_dict)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1746, in _update_policy_map
(RolloutWorker pid=90863) self._build_policy_map(
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1857, in _build_policy_map
(RolloutWorker pid=90863) new_policy = create_policy_for_framework(
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 141, in create_policy_for_framework
(RolloutWorker pid=90863) return policy_class(observation_space, action_space, merged_config)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
(RolloutWorker pid=90863) self._initialize_loss_from_dummy_batch()
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 1430, in _initialize_loss_from_dummy_batch
(RolloutWorker pid=90863) actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/policy/torch_policy_v2.py", line 572, in compute_actions_from_input_dict
(RolloutWorker pid=90863) return self._compute_action_helper(
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
(RolloutWorker pid=90863) return func(self, *a, **k)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1293, in _compute_action_helper
(RolloutWorker pid=90863) dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 263, in __call__
(RolloutWorker pid=90863) res = self.forward(restored, state or [], seq_lens)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/torch/attention_net.py", line 442, in forward
(RolloutWorker pid=90863) self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 251, in __call__
(RolloutWorker pid=90863) restored["obs"] = restore_original_dimensions(
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 417, in restore_original_dimensions
(RolloutWorker pid=90863) return _unpack_obs(obs, original_space, tensorlib=tensorlib)
(RolloutWorker pid=90863) File "/Users/matteotortora/ENTER/envs/drl-ec/lib/python3.9/site-packages/ray/rllib/models/modelv2.py", line 451, in _unpack_obs
(RolloutWorker pid=90863) raise ValueError(
(RolloutWorker pid=90863) ValueError: Expected flattened obs shape of [..., 456], got torch.Size([32, 1])
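One thing I noticed: 456 matches the flattened size of my full Dict observation (8*12 + 8*45 = 96 + 360 = 456). To inspect what the two wrapped sub-models actually expect, I can construct the model above in isolation; this is only a hypothetical debug sketch, the spaces, action space and num_outputs are placeholders:

import numpy as np
from gymnasium import spaces

# Rebuild the same Dict observation space used inside the model.
full_obs_space = spaces.Dict({
    "actor_obs": spaces.Box(-np.inf, np.inf, (8, 12), np.float64),
    "critic_obs": spaces.Box(-np.inf, np.inf, (8, 45), np.float64),
})
debug_action_space = spaces.Discrete(4)  # placeholder

debug_model = YetAnotherTorchCentralizedCriticModel(
    full_obs_space,
    debug_action_space,
    2,  # placeholder num_outputs
    {"custom_model": "cc_model", "use_attention": True,
     "attention_num_heads": 8, "max_seq_len": 8},
    "debug_model",
)
print(type(debug_model.action_model), debug_model.action_model.obs_space)
print(type(debug_model.value_model), debug_model.value_model.obs_space)
print(list(debug_model.action_model.view_requirements.keys()))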
What am I doing wrong? Any tips?