Extremely oscillating MAPPO reward on custom env

Hi everyone,

I am currently working on a custom PettingZoo environment with multiple agents trying to control a spreading wildfire by applying a fire retardant action. Here is a sample of the environment grid-world (the action space is discrete):


The following are my environment reward components (computed independently for each agent's policy), plus a final reward equal to the double log of the number of remaining healthy trees:

    def reward_burning_trees(self, prev_state: np.ndarray, next_state: np.ndarray, agent_position: np.ndarray, other_agents_positions: np.ndarray, action: np.ndarray | int, *args, **kwargs) -> float:
        """negative reward for every newly burning tree in the env"""
        
        
        # to tackle high-vegetation fires first?
        old_burning = (np.isin(prev_state, [TreeCellStateV0.BURNING, TreeCellStateV0.BURNT, TreeCellStateV0.RETARDANT_APPLIED_ON_BURNING]) * self.vegetation_plane).sum()
        new_burning = (np.isin(next_state, [TreeCellStateV0.BURNING, TreeCellStateV0.BURNT, TreeCellStateV0.RETARDANT_APPLIED_ON_BURNING]) * self.vegetation_plane).sum()
        
        severity_penalty = new_burning - old_burning
        
        distance_from_closest_fire = scipy.spatial.distance.cdist(agent_position[np.newaxis, :], self.burning_tree_indices).min()
        
        ## penalize the agent even more if it is near a fire but did not apply retardant
        if distance_from_closest_fire < 3 and action != DroneActionV0.APPLY_FIRE_RETARDANT:
            severity_penalty *= 1.5
        
        return - np.log(severity_penalty + 1).item()
    
    def reward_retardant_proximity(self, prev_state: np.ndarray, next_state: np.ndarray, agent_position: np.ndarray, other_agents_positions: np.ndarray, action: np.ndarray | int, *args, **kwargs):
        """positive reward if the agent applies retardant close to a cell on fire"""
        
        reward = 0
        if action == DroneActionV0.APPLY_FIRE_RETARDANT and prev_state[agent_position[0], agent_position[1]] != TreeCellStateV0.BURNT:
            min_distance = scipy.spatial.distance.cdist(agent_position[np.newaxis,:],self.burning_tree_indices).min()
            ## if the closest tree is closer than 3 units on the map
            if min_distance < 3:
                ## we sum the covered vegetation area of non-burnt trees to signify the amount of saved vegetation as reward
                affected_area = prev_state[max(0,agent_position[0] -1): min(len(prev_state), agent_position[0] + 2),
                    max(0,agent_position[1] -1): min(len(prev_state), agent_position[1] + 2)]
                ## we only reward vegetation that is not (already burnt or with retardant already applied)
                affected_vegetation = self.vegetation_plane[max(0,agent_position[0] -1): min(len(self.vegetation_plane), agent_position[0] + 2),
                    max(0,agent_position[1] -1): min(len(self.vegetation_plane), agent_position[1] + 2)] * (np.isin(affected_area, [TreeCellStateV0.HEALTHY, TreeCellStateV0.BURNING]))
                
                reward = affected_vegetation.sum()
                
        return reward
            
    
    def reward_agents_proximity(self, prev_state: FireSpreadEnvState, next_state: FireSpreadEnvState, agent_position: np.ndarray, other_agents_positions: np.ndarray, action: np.ndarray | int, safe_distance: float=2.0, *args, **kwargs) -> float:
        """negative reward for agents being closer to each other than safe_distance"""
        
        distance = (np.linalg.norm(agent_position - other_agents_positions, axis=1))
        distance_penalty = ((safe_distance - distance) * (distance < safe_distance)).sum()
        
        
        return - distance_penalty    

I decided to use distributed multi-agent PPO with the standard PPOConfig, as follows:

"""Uses Ray's RLlib to train agents to play Pistonball.

Author: Rohan (https://github.com/Rohan138)
"""

import os

import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
from torch import nn
from ray.air.integrations.wandb import WandbLoggerCallback


class ImprovedCNNModelV2(TorchModelV2, nn.Module):
    def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs):
        TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
        nn.Module.__init__(self)

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=64, kernel_size=8, stride=4),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        self.residual_layers = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
        )

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(1024),  # Adjust the input size depending on the output size of the conv layers
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )

        self.policy_fn = nn.Linear(512, num_outputs)
        self.value_fn = nn.Linear(512, 1)

    def forward(self, input_dict, state, seq_lens):
        x = self.conv_layers(input_dict["obs"].permute(0, 3, 1, 2))

        # Apply residual connections
        res_x = self.residual_layers(x)
        x = nn.ReLU()(x + res_x)

        # Pass through fully connected layers
        x = self.fc_layers(x)
        
        self._value_out = self.value_fn(x)
        return self.policy_fn(x), state

    def value_function(self):
        return self._value_out.flatten()


def env_creator(args):
    env = (FireControlOnlyEnvV0(grid_size=40, fire_deadzone=18, n_agents=2))
    return env


if __name__ == "__main__":
    ray.init()

    env_name = "wildfire_control_v0"

    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator({})))
    ModelCatalog.register_custom_model("CNNModelV2", ImprovedCNNModelV2)
    test_env = env_creator({})
    obs_space = test_env.observation_space
    act_space = test_env.action_space
    config = (
        PPOConfig()
        .environment(env=env_name, clip_actions=True)
        .rollouts(num_rollout_workers=3, rollout_fragment_length='auto')
        .training(
            model={
                "custom_model": "CNNModelV2"
            },
            train_batch_size=512,
            gamma=0.99,
            lambda_=0.9,
            use_gae=True,
            lr=tune.loguniform(1e-5, 1e-3),
            clip_param=tune.uniform(0.2, 0.4),
            grad_clip=None,
            entropy_coeff=tune.uniform(0.01, 0.3),
            vf_loss_coeff=0.25,
            sgd_minibatch_size=64,
            num_sgd_iter=tune.choice([5, 10, 20]),
        ) .multi_agent(
            policies={
                agent: (None, obs_space, act_space, {})
                for agent in test_env.unwrapped.agents
            },
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )
        .debugging(log_level="ERROR")
        .framework(framework="torch")
        .resources(num_gpus=int(1))
    )
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("WANDB_KEY")
    results = tune.run(
        "PPO",
        name="PPO",
        stop={"timesteps_total": 500000 },
        checkpoint_freq=10,
        local_dir="/kaggle/working/ray_results/" + env_name,
        config=config.to_dict(),
        num_samples=3,
        callbacks=[
            WandbLoggerCallback(project="FireFlyTest", group="*******", api_key=secret_value_0)
        ]
    )

After almost 1.5M timesteps in total, I am facing a highly oscillating reward that does not improve in the slightest:

I am mostly suspicious of the reward formulation.

I would appreciate any tips or general guidance!

In general, the strategy I’ve seen for debugging agents that fail to learn like this (they seem to get a higher reward in the first epoch than in any epoch thereafter, unless that’s an artifact of your visualization function) is to simplify the environment dramatically until the issue becomes clearer. RL has lots of moving parts. Figuring out why a complex system doesn’t work at all is easier than tracking down the little issues that make it work but ‘less well’, yet it is still quite tricky. I would recommend stripping out the multi-agent part, shrinking the field down to a toy problem, and seeing if it works there. If it does, add the complexity back step by step and see where it stops working. Simplifying the reward formulation is also worth doing: I’ve generally only seen reward shaping done after a “successful-ish” run with a naive reward function.
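
As a rough illustration of what a naive starting point could look like, here is a minimal sketch that only penalizes cells newly lost to fire and drops all the shaping terms. The `CellState` enum is a hypothetical stand-in for `TreeCellStateV0` (its real values aren't shown in the post), and `naive_reward` is a made-up name, so treat this as an assumption-laden sketch rather than a drop-in replacement.

```python
from enum import IntEnum

import numpy as np


class CellState(IntEnum):
    """Hypothetical stand-in for TreeCellStateV0; the real values may differ."""
    HEALTHY = 0
    BURNING = 1
    BURNT = 2


def naive_reward(prev_state: np.ndarray, next_state: np.ndarray) -> float:
    """Return minus the number of cells that caught fire during this step."""
    lost_states = [int(CellState.BURNING), int(CellState.BURNT)]
    # Count cells that are burning or already burnt before and after the step.
    prev_lost = np.isin(prev_state, lost_states).sum()
    next_lost = np.isin(next_state, lost_states).sum()
    # Burnt cells stay counted, so the difference is the number of new ignitions.
    return -float(next_lost - prev_lost)
```

If a single agent on a small grid can't improve even this signal, the problem is probably somewhere other than the reward shaping.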

As a side note, MAPPO is not the same thing as independent multi-agent PPO (IPPO), which is what your config sets up: one separate policy per agent, each with its own value head. I’ve actually implemented MAPPO for something I’m working on; you just add a shared critic.
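
To make the difference concrete, here is a shape-level sketch of what the shared critic sees. It is a toy linear value function in plain NumPy, not RLlib code, and the agent names (`drone_0`, `drone_1`), the observation shape, and both helper names are assumptions made up for the example. The point is only that the critic takes the joint observation of all agents, while each actor still sees just its own.

```python
import numpy as np


def joint_observation(obs_by_agent: dict) -> np.ndarray:
    """Flatten and concatenate all agents' observations in a fixed agent order."""
    return np.concatenate(
        [np.asarray(obs_by_agent[agent], dtype=np.float32).ravel()
         for agent in sorted(obs_by_agent)]
    )


class LinearCentralCritic:
    """Toy linear value function over the joint observation (illustration only)."""

    def __init__(self, joint_dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(scale=0.01, size=joint_dim).astype(np.float32)
        self.b = 0.0

    def value(self, obs_by_agent: dict) -> float:
        # One scalar estimate for the whole team, computed from everyone's view.
        return float(self.w @ joint_observation(obs_by_agent) + self.b)


# Two hypothetical agents with 3x3x2 observations: 2 * 18 = 36 critic inputs.
obs = {"drone_0": np.ones((3, 3, 2)), "drone_1": np.zeros((3, 3, 2))}
critic = LinearCentralCritic(joint_dim=joint_observation(obs).size)
shared_value = critic.value(obs)  # used as the baseline for every agent's advantage
```

In a real setup the critic would be a network trained on the joint observation (or a global state), and it is only needed during training; at execution time each agent acts from its own observation.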

Edit: Sorry to readers for the necropost. I clicked ‘recent’ and saw August, and hadn’t noticed it was in the wrong year.