Achieving individual rewards with agent grouping

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hello everyone,

I’m currently working on integrating a paper about communication in multi-agent reinforcement learning. The communication happens within the step: the agents encode their local observations, create a message, and exchange it with the other agents. The encoded observations, together with the aggregated received messages, are then used to determine the next action.
Because the messages are created and exchanged within the step, I have to process all agents collectively. After some research I stumbled upon agent grouping, which worked well for integrating the communication architecture. However, when grouping agents, RLlib treats the agent group as a single agent and sums the individual agent rewards into one group reward.
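As far as I can tell, the grouping collapses the per-agent rewards into a single scalar for the whole group, roughly like this (just a conceptual sketch with made-up agent names and values, not RLlib's actual code):

per_agent_rewards = {'agent_0': 1.0, 'agent_1': -0.5}    # what the underlying env returns
group_reward = {'all': sum(per_agent_rewards.values())}  # what the grouped env reports

This is exactly the part I would like to change: each agent should keep its own reward signal.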

I tried to figure out at which point I would have to modify the policy to obtain individual rewards, but since I’m fairly new to RLlib, I’m struggling to find a proper way to do it.

Below you will find a simplified version of the communication model and training script.

Thanks in advance for your feedback.

import torch
import numpy as np
from torch import nn
import ray
from ray import air
from ray import tune
from ray.tune import CLIReporter
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.env import ParallelPettingZooEnv
from pettingzoo.mpe import simple_spread_v2
from gymnasium.spaces import Tuple

algorithm = 'PPO'
num_agents = 2

class GroupModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config={}, name='Test Model', **kwargs):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        num_hidden = 128
        self.model_config = model_config
        # self.num_message = model_config['custom_model_config'].get('num_message', 4)
        self.num_message = num_hidden
        self.num_agents = model_config['custom_model_config']['num_agents']
        self.obs_space = obs_space
        self.action_space = action_space
        # We divide by num_agents, because the preprocessor flattens the group obs to one vector
        self.num_outputs = num_outputs // self.num_agents
        self.num_inputs = int(np.prod(obs_space.shape) // self.num_agents)
        
        self.obs_encoder = nn.Sequential(
            nn.Linear(self.num_inputs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
            nn.ReLU()
        )
        # The actor takes the output of self.obs_encoder and the other agents' aggregated messages as input
        self.actor = nn.Sequential(
            nn.Linear(num_hidden + self.num_message, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, self.num_outputs)
        )
        self.message_module = nn.Sequential(
            nn.Linear(num_hidden + self.num_message, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, self.num_message)
        )
        self.critic = nn.Sequential(
            nn.Linear(self.num_inputs + self.num_outputs + self.num_message, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, 1)
        ) 

    def get_initial_state(self):
        # The state is used to propagate the message to the other agents
        return [torch.zeros(self.num_message)] * self.num_agents

    def forward(self, input_dict, state, seq_len):
        self._last_batch_dim = input_dict['obs_flat'].shape[0]
        # Reshape obs so that each agent is processed independently
        self._last_obs = input_dict['obs_flat'].view((-1, self.num_inputs))
        z = self.obs_encoder(self._last_obs)
        self._last_received_messages = self._generate_received_messages(z)
        z_cat = torch.cat((z, self._last_received_messages), dim=1)
        self._last_actions = self.actor(z_cat)
        self._last_messages = self.message_module(z_cat)
        # The preprocessor splits the actions back into the per-agent action spaces later on, so we concatenate them into one flat output per group here.
        return self._last_actions.view(self._last_batch_dim, -1), []
    
    def value_function(self):
        x = torch.cat((self._last_obs, self._last_actions, self._last_received_messages), dim=1)
        return torch.sum(self.critic(x).view(self._last_batch_dim, -1), dim=1)
    
    def _generate_received_messages(self, messages):
        n_messages = messages.shape[0]
        mean_message = torch.mean(messages, dim=0)
        return mean_message.repeat(n_messages, 1)
    
# Instantiate the environment once to look up the agent list and the per-agent observation/action spaces.
test_env = ParallelPettingZooEnv(simple_spread_v2.parallel_env(N=num_agents))
agents = test_env.par_env.agents
obs_space = Tuple(test_env.observation_space for _ in agents)
act_space = Tuple(test_env.action_space for _ in agents)

# Put all agents into a single group so the model can process them jointly.
def env_creator(config=None):
    env = ParallelPettingZooEnv(simple_spread_v2.parallel_env(**(config or {})))
    return env.with_agent_groups({'all': agents}, obs_space=obs_space, act_space=act_space)

ray.init(local_mode=True)

# Register the custom model with RLlib and the grouped environment with Tune.
ModelCatalog.register_custom_model("group_model", GroupModel)
tune.register_env("simple_spread", env_creator)

# Build the algorithm config: use the custom group model and pass the number of agents via custom_model_config.
config = (AlgorithmConfig(algorithm)
        .framework('torch')
        .environment(env="simple_spread", env_config={'N': num_agents})
        .training(
                train_batch_size=128,
                model={
                    'custom_model': 'group_model',
                    'custom_model_config': {'num_agents': num_agents},
                },
            )
        .rollouts(
            num_rollout_workers=0,
            rollout_fragment_length=32,
            )
)

# Run a short training with Tune (stops after 10 training iterations).
tune.Tuner(
    algorithm,
    param_space=config,
    run_config=air.RunConfig(
        stop={"training_iteration": 10},
        progress_reporter=CLIReporter(
            metric_columns={ 
                "training_iteration": "iter",
                "time_total_s": "time_total_s",
                "episode_reward_mean": "reward_mean",
            },
            max_report_frequency=10,
        ),
    ),
).fit()
ray.shutdown()