Extremely oscillating MAPPO reward on custom env

Hi everyone,

I am currently working on a custom PettingZoo environment where multiple agents try to contain a spreading wildfire by applying fire retardant. Here is a sample of the environment grid world (the action space is discrete):


The following are my environment reward components (computed independently for each agent's policy), plus a final reward equal to the double log of the number of remaining healthy trees (a rough sketch of that terminal term follows after the code below):

    def reward_burning_trees(self, prev_state: np.ndarray, next_state: np.ndarray, agent_position: np.ndarray, other_agents_positions: np.ndarray, action: np.ndarray | int, *args, **kwargs) -> float:
        """Negative reward for every newly burning tree in the env, weighted by vegetation density."""
        
        # weight by vegetation density so that high-vegetation fires are tackled first
        old_burning = (np.isin(prev_state, [TreeCellStateV0.BURNING, TreeCellStateV0.BURNT, TreeCellStateV0.RETARDANT_APPLIED_ON_BURNING]) * self.vegetation_plane).sum()
        new_burning = (np.isin(next_state, [TreeCellStateV0.BURNING, TreeCellStateV0.BURNT, TreeCellStateV0.RETARDANT_APPLIED_ON_BURNING]) * self.vegetation_plane).sum()
        
        severity_penalty = new_burning - old_burning
        
        # distance from the agent to the closest burning tree (assumes at least one tree is burning)
        distance_from_closest_fire = scipy.spatial.distance.cdist(agent_position[np.newaxis, :], self.burning_tree_indices).min()
        
        # penalize the agent even more if it was near a fire but did not apply retardant
        if distance_from_closest_fire < 3 and action != DroneActionV0.APPLY_FIRE_RETARDANT:
            severity_penalty *= 1.5
        
        return -np.log(severity_penalty + 1).item()
    
    def reward_retardant_proximity(self, prev_state: np.ndarray, next_state: np.ndarray, agent_position: np.ndarray, other_agents_positions: np.ndarray, action: np.ndarray | int, *args, **kwargs) -> float:
        """Positive reward if the agent applies retardant close to a cell on fire."""
        
        reward = 0
        if action == DroneActionV0.APPLY_FIRE_RETARDANT and prev_state[agent_position[0], agent_position[1]] != TreeCellStateV0.BURNT:
            min_distance = scipy.spatial.distance.cdist(agent_position[np.newaxis, :], self.burning_tree_indices).min()
            # if the closest burning tree is closer than 3 units on the map
            if min_distance < 3:
                # sum the covered vegetation area of non-burnt trees to quantify the amount of saved vegetation
                affected_area = prev_state[max(0, agent_position[0] - 1): min(len(prev_state), agent_position[0] + 2),
                                           max(0, agent_position[1] - 1): min(len(prev_state), agent_position[1] + 2)]
                # only reward vegetation that is not already burnt and not already covered by retardant
                affected_vegetation = self.vegetation_plane[max(0, agent_position[0] - 1): min(len(self.vegetation_plane), agent_position[0] + 2),
                                                            max(0, agent_position[1] - 1): min(len(self.vegetation_plane), agent_position[1] + 2)] * np.isin(affected_area, [TreeCellStateV0.HEALTHY, TreeCellStateV0.BURNING])
                
                reward = affected_vegetation.sum()
        
        return reward
            
    
    def reward_agents_proximity(self, prev_state: FireSpreadEnvState, next_state: FireSpreadEnvState, agent_position: np.ndarray, other_agents_positions: np.ndarray, action: np.ndarray | int, safe_distance: float = 2.0, *args, **kwargs) -> float:
        """Negative reward for agents being closer to each other than safe_distance."""
        
        # penalize only the agents within safe_distance, proportionally to how close they are
        distance = np.linalg.norm(agent_position - other_agents_positions, axis=1)
        distance_penalty = ((safe_distance - distance) * (distance < safe_distance)).sum()
        
        return -distance_penalty
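
For reference, the terminal reward mentioned above looks roughly like this (a simplified sketch, not the exact code; the method name is a placeholder, and I count healthy cells at episode end and pass them through a nested log with +1 offsets so it stays defined):

    def reward_final_healthy_trees(self, final_state: np.ndarray, *args, **kwargs) -> float:
        """Terminal bonus: double log of the number of trees that remained healthy (sketch)."""
        # count cells that never caught fire
        healthy_tree_count = np.isin(final_state, [TreeCellStateV0.HEALTHY]).sum()
        # +1 inside both logs so the value is defined even if (almost) nothing survives
        return np.log(np.log(healthy_tree_count + 1) + 1).item()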

I decided to use distributed multi-agent PPO with the standard PPOConfig, as follows:

"""Uses Ray's RLlib to train agents to play Pistonball.

Author: Rohan (https://github.com/Rohan138)
"""

import os

import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env
from torch import nn
from ray.air.integrations.wandb import WandbLoggerCallback


class ImprovedCNNModelV2(TorchModelV2, nn.Module):
    def __init__(self, obs_space, act_space, num_outputs, *args, **kwargs):
        TorchModelV2.__init__(self, obs_space, act_space, num_outputs, *args, **kwargs)
        nn.Module.__init__(self)

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=6, out_channels=64, kernel_size=8, stride=4),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )

        self.residual_layers = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
        )

        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.LazyLinear(1024),  # LazyLinear infers the flattened input size from the conv output on the first forward pass
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(p=0.5),
        )

        self.policy_fn = nn.Linear(512, num_outputs)
        self.value_fn = nn.Linear(512, 1)

    def forward(self, input_dict, state, seq_lens):
        # NHWC observations -> NCHW for the conv stack
        x = self.conv_layers(input_dict["obs"].permute(0, 3, 1, 2))

        # Apply residual connections
        res_x = self.residual_layers(x)
        x = nn.ReLU()(x + res_x)

        # Pass through fully connected layers
        x = self.fc_layers(x)
        
        self._value_out = self.value_fn(x)
        return self.policy_fn(x), state

    def value_function(self):
        return self._value_out.flatten()


def env_creator(args):
    return FireControlOnlyEnvV0(grid_size=40, fire_deadzone=18, n_agents=2)


if __name__ == "__main__":
    ray.init()

    env_name = "wildfire_control_v0"

    register_env(env_name, lambda config: ParallelPettingZooEnv(env_creator({})))
    ModelCatalog.register_custom_model("CNNModelV2", ImprovedCNNModelV2)
    test_env = env_creator({})
    obs_space = test_env.observation_space
    act_space = test_env.action_space
    config = (
        PPOConfig()
        .environment(env=env_name, clip_actions=True)
        .rollouts(num_rollout_workers=3, rollout_fragment_length='auto')
        .training(
            model={
                "custom_model": "CNNModelV2"
            },
            train_batch_size=512,
            gamma=0.99,
            lambda_=0.9,
            use_gae=True,
            lr=tune.loguniform(1e-5, 1e-3),
            clip_param=tune.uniform(0.2, 0.4),
            grad_clip=None,
            entropy_coeff=tune.uniform(0.01, 0.3),
            vf_loss_coeff=0.25,
            sgd_minibatch_size=64,
            num_sgd_iter=tune.choice([5, 10, 20]),
        )
        .multi_agent(
            # one independent PPO policy per agent id (no parameter sharing)
            policies={
                agent: (None, obs_space, act_space, {})
                for agent in test_env.unwrapped.agents
            },
            policy_mapping_fn=(lambda agent_id, *args, **kwargs: agent_id),
        )
        .debugging(log_level="ERROR")
        .framework(framework="torch")
        .resources(num_gpus=1)
    )
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    secret_value_0 = user_secrets.get_secret("WANDB_KEY")
    results = tune.run(
        "PPO",
        name="PPO",
        stop={"timesteps_total": 500000 },
        checkpoint_freq=10,
        local_dir="/kaggle/working/ray_results/" + env_name,
        config=config.to_dict(),
        num_samples=3,
        callbacks=[
            WandbLoggerCallback(project="FireFlyTest", group="*******", api_key=secret_value_0)
        ]
    )

After a total of almost 1.5M timesteps, I am seeing a highly oscillating episode reward that does not improve in the slightest:

I am mostly suspicious of the reward formulation.
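
To narrow this down, my next step is to log each reward component separately during rollouts, roughly like this (a rough sketch; log_reward_components is a hypothetical helper, assuming the three methods above are summed inside step()):

    def log_reward_components(self, prev_state, next_state, agent_position, other_agents_positions, action) -> dict:
        """Hypothetical helper: evaluate each reward term separately to see which one oscillates."""
        components = {
            "burning_trees": self.reward_burning_trees(prev_state, next_state, agent_position, other_agents_positions, action),
            "retardant_proximity": self.reward_retardant_proximity(prev_state, next_state, agent_position, other_agents_positions, action),
            "agents_proximity": self.reward_agents_proximity(prev_state, next_state, agent_position, other_agents_positions, action),
        }
        # the dict can then be exposed via the per-agent info and aggregated in W&B
        return components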

I'd appreciate any tips or general guidance!