RLlib and AirSim - multi-agent reinforcement learning (MAPPO) observation space issue

1. Severity of the issue: (select one)
Medium: Significantly affects my productivity, but I can find a workaround.

2. Environment:

  • Ray version: 2.44.1
  • Python version: 3.9.13

Hi everyone,

I have been developing a multi-agent reinforcement learning (MAPPO) system using Ray RLlib, with a centralized critic, for a multi-agent AirSim environment.

I’ve been working on configuring the action and observation spaces, but it seems that somewhere in the loop the observations do not come back with the shape or type that the declared observation space expects.
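
To make the symptom concrete, a check along the following lines should show whether what the environment returns matches the spaces it declares (a minimal sketch, assuming reset() follows the PettingZoo ParallelEnv API and returns (observations, infos); Space.contains() is Gymnasium's standard way to validate a sample against a declared space):

  from airgym.envs.rllib_multiagent_drone_env_v4 import MultiAgentAirSimEnv

  # Build the raw environment and compare what reset() returns against the
  # per-agent observation spaces declared in __init__.
  env = MultiAgentAirSimEnv(ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2)
  obs, infos = env.reset()

  for agent, agent_obs in obs.items():
      declared = env.observation_spaces[agent]
      print(agent, type(agent_obs))
      print(agent, "matches declared space:", declared.contains(agent_obs))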

Below, I’ve included the main training loop, environment initialization, and the observation function defined in the environment script.

I appreciate any comments or suggestions you can provide.

Thank you,

Main training loop


  from airgym.envs.rllib_multiagent_drone_env_v4 import MultiAgentAirSimEnv
  from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
  
  import ray
  from ray import tune
  from ray.rllib.algorithms.ppo import PPOConfig
  from ray.tune.registry import register_env
  from ray.rllib.env.multi_agent_env import make_multi_agent
  from ray.rllib.policy.policy import PolicySpec
  
  import torch
  import torch.nn as nn
  import numpy as np
  
  from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
  from ray.rllib.models.modelv2 import ModelV2
  from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
  from ray.rllib.utils.annotations import override
  from ray.rllib.models.torch.misc import SlimFC
  from ray.rllib.utils.framework import try_import_torch
  from ray.rllib.utils.typing import ModelConfigDict, TensorType
  from typing import Dict, List, Union
  
  torch, nn = try_import_torch()
  
  class CentralizedCriticModel(TorchModelV2, nn.Module):
      def __init__(self, obs_space, action_space, num_outputs, model_config: ModelConfigDict, name: str):
          TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
          nn.Module.__init__(self)
  
          image_shape = obs_space["image"].shape
          global_state_shape = obs_space["global_state"].shape[0]
  
          self.actor_conv = nn.Sequential(
              nn.Conv2d(image_shape[2], 32, 8, stride=4),
              nn.ReLU(),
              nn.Conv2d(32, 64, 4, stride=2),
              nn.ReLU(),
              nn.Conv2d(64, 64, 3, stride=1),
              nn.ReLU(),
              nn.Flatten(),
              nn.Linear(3136, 512),
              nn.ReLU(),
          )
          self.logits = nn.Linear(512, num_outputs)
  
          self.critic_mlp = nn.Sequential(
              nn.Linear(global_state_shape, 256),
              nn.ReLU(),
              nn.Linear(256, 256),
              nn.ReLU(),
              nn.Linear(256, 1),
          )
  
          self._value_out = None
  
      @override(ModelV2)
      def forward(self, input_dict, state, seq_lens):
          obs = input_dict["obs"]
  
          obs_image = obs["image"].float().permute(0, 3, 1, 2) / 255.0
  
          # global state for centralized critic
          self._obs_global_state = obs["global_state"].float()  # [B, 6]
  
          features = self.actor_conv(obs_image)  # [B, 512]
          logits = self.logits(features)  # [B, num_outputs]
  
          return logits, state
  
      @override(ModelV2)
      def value_function(self):
  
          out = self.critic_mlp(self._obs_global_state)
          return torch.reshape(out, [-1])  # Return as a 1D tensor (batch_size,)
  
  from ray.rllib.models import ModelCatalog
  
  ModelCatalog.register_custom_model("centralized_critic", CentralizedCriticModel)
  
  class Wrapper:
      def __init__(self, env):
          self.env = env
          self.possible_agents = env.possible_agents
          self.agents = env.possible_agents
          self.metadata = getattr(env, "metadata", {})

          # Per-agent space accessors, exposed as callables for ParallelPettingZooEnv.
          self.observation_space = lambda agent_id: env.observation_space(agent_id)
          self.action_space = lambda agent_id: env.action_space[agent_id]

          self.reset = env.reset
          self.step = env.step

      def __getattr__(self, name):
          # Delegate any other attribute lookups to the wrapped env.
          return getattr(self.env, name)
  
  
  def env_creator(config):
      raw_env = MultiAgentAirSimEnv(ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2)
      wrapped_env = Wrapper(raw_env)
      return ParallelPettingZooEnv(wrapped_env)
  
  
  register_env("airsim_multiagent", env_creator)
  
  def policy_mapping_fn(agent_id, episode, worker, **kwargs):
      return "shared_policy"
  
  config = (
      PPOConfig()
      .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)  # Disable new API stack
      .environment("airsim_multiagent")
      .framework("torch")
      .training(
          model={"custom_model": "centralized_critic","vf_share_layers": False},
          train_batch_size=4000,  # This is the main batch size
          gamma=0.99,
          lambda_=0.95,
          vf_clip_param=10.0,
          clip_param=0.2,
          entropy_coeff=0.01,
          num_sgd_iter=10,  # Number of SGD iterations
          use_gae=True,
      )
      .multi_agent(
          policies={
              "shared_policy": PolicySpec(),
          },
          policy_mapping_fn=policy_mapping_fn,
          policies_to_train=["shared_policy"],
      )
      .env_runners(num_env_runners=1)  
  )
  
  
  ray.init(ignore_reinit_error=True)
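
  # For context only: a minimal sketch of how a config like the above can be
  # launched on the old API stack (abridged, not my exact run script; the
  # metrics key can differ between RLlib versions, hence .get()).
  algo = config.build()
  for i in range(10):
      result = algo.train()
      print("iter", i, "episode_reward_mean:", result.get("episode_reward_mean"))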

Environment Initialization


import setup_path
import airsim
import numpy as np
import math
import time

import sys
import cv2
from PIL import Image

import logging
logger = logging.getLogger(__name__)

import gymnasium as gym
from gymnasium import spaces
from gymnasium.spaces import Discrete, MultiDiscrete, Box, Tuple
from gymnasium.utils import EzPickle

from pettingzoo import ParallelEnv, AECEnv
from pettingzoo.utils import agent_selector, wrappers

from pettingzoo.utils.env import ObsDict, ActionDict
import random

from collections import defaultdict
import json


class MultiAgentAirSimEnv(ParallelEnv, EzPickle):
    metadata = {"render.modes": ["human"], "name": "airsim-drone-sample-v0"}

    def __init__(self, ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2):
        EzPickle.__init__(self, ip_address, step_length, image_shape, n_agents)

        self.render_mode = "human"
        self.step_length = step_length
        self.image_shape = image_shape

        self.current_step = 0
        self.n_agents = n_agents

        self.drone = airsim.MultirotorClient(ip=ip_address)

        self.image_request = airsim.ImageRequest(
            3, airsim.ImageType.DepthPerspective, True, False
        )

        self.possible_agents = [f"Drone{i}" for i in range(n_agents)]
        self.agents = self.possible_agents[:]


        self.action_space = {agent: spaces.Discrete(3) for agent in self.possible_agents}


        # Per-agent observation: a depth image plus a global-state vector
        # holding (x, y, z) for every drone.
        image_space = Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
        global_state_space = Box(low=-np.inf, high=np.inf, shape=(n_agents * 3,), dtype=np.float32)

        observation_space = Tuple([image_space, global_state_space])


        self.observation_spaces = {
            agent: observation_space for agent in self.possible_agents
        }
        self.action_spaces = {
            agent: self.action_space[agent] for agent in self.possible_agents
        }

        self.episodeN = 0

        self._setup_flight()

        self.training_mode = True

Observation Function


    def _get_obs(self):
        observations = {}
        global_state = [] 

        for agent in self.agents:
            responses = self.drone.simGetImages([self.image_request], vehicle_name=agent)
            image = self.transform_obs(responses)  # Convert AirSim response to NumPy array


            # Reorder to channels-first so the channel count can be inspected.
            if image.ndim == 3:
                image = np.transpose(image, (2, 1, 0))

            if image.shape[0] == 1:
                # Single channel: drop the channel axis, then re-add it as the last axis.
                image = np.squeeze(image)
                image = np.expand_dims(image, axis=-1)
            elif image.shape[0] == 3:
                # Three channels: move the channel axis back to the end.
                image = np.transpose(image, (1, 2, 0))

            # Per-agent global state: this drone's ground-truth position (x, y, z).
            pos = self.drone.simGetGroundTruthKinematics(vehicle_name=agent).position
            global_state_agent = np.array([pos.x_val, pos.y_val, pos.z_val], dtype=np.float32)

            agent_obs = {
                "image": image.astype(np.float32),
                "global_state": global_state_agent,
            }
            observations[agent] = agent_obs

        return observations
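
Building on the sanity check near the top of the post, a finer-grained dump of what _get_obs() actually produces per agent can be obtained like this (a minimal sketch; assumes a running AirSim instance and an environment constructed as in the initialization section above):

from airgym.envs.rllib_multiagent_drone_env_v4 import MultiAgentAirSimEnv

env = MultiAgentAirSimEnv(ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2)

# Print the structure, shape, and dtype of each agent's observation entries.
obs = env._get_obs()
for agent, agent_obs in obs.items():
    for key, value in agent_obs.items():
        print(agent, key, value.shape, value.dtype)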