1. Severity of the issue: (select one)
Medium: Significantly affects my productivity but can find a workaround.
2. Environment:
- Ray version: 2.44.1
- Python version: 3.9.13
Hi everyone,
I have been developing a multi-agent reinforcement learning (MAPPO) system with a centralized critic, using Ray RLlib, for a multi-agent AirSim environment.
I've been working on configuring the action and observation spaces, but somewhere in the loop the observations do not seem to come back with the shape or type that the declared observation space promises.
Below, I’ve included the main training loop, environment initialization, and the observation function defined in the environment script.
I appreciate any comments or suggestions you can provide.
Thank you,
Main training loop
from airgym.envs.rllib_multiagent_drone_env_v4 import MultiAgentAirSimEnv
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
import ray
from ray import tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from ray.rllib.env.multi_agent_env import make_multi_agent
from ray.rllib.policy.policy import PolicySpec
import torch
import torch.nn as nn
import numpy as np
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import ModelConfigDict, TensorType
from typing import Dict, List, Union
torch, nn = try_import_torch()
class CentralizedCriticModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config: ModelConfigDict, name: str):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        image_shape = obs_space["image"].shape
        global_state_shape = obs_space["global_state"].shape[0]

        self.actor_conv = nn.Sequential(
            nn.Conv2d(image_shape[2], 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
        )
        self.logits = nn.Linear(512, num_outputs)

        self.critic_mlp = nn.Sequential(
            nn.Linear(global_state_shape, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )
        self._value_out = None

    @override(ModelV2)
    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        obs_image = obs["image"].float().permute(0, 3, 1, 2) / 255.0
        # Global state for the centralized critic.
        self._obs_global_state = obs["global_state"].float()  # [B, 6]
        features = self.actor_conv(obs_image)  # [B, 512]
        logits = self.logits(features)  # [B, num_outputs]
        return logits, state

    @override(ModelV2)
    def value_function(self):
        out = self.critic_mlp(self._obs_global_state)
        return torch.reshape(out, [-1])  # Return as a 1D tensor (batch_size,)
from ray.rllib.models import ModelCatalog
ModelCatalog.register_custom_model("centralized_critic", CentralizedCriticModel)
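
# Standalone shape check for the custom model (sketch only): the Dict space below is an
# assumption about the per-agent observation layout, chosen to match the keys forward() indexes.
import gymnasium as gym

_check_obs_space = gym.spaces.Dict({
    "image": gym.spaces.Box(0, 255, (84, 84, 1), dtype=np.uint8),
    "global_state": gym.spaces.Box(-np.inf, np.inf, (6,), dtype=np.float32),
})
_check_model = CentralizedCriticModel(_check_obs_space, gym.spaces.Discrete(3), 3, {}, "shape_check")
_logits, _ = _check_model.forward(
    {"obs": {"image": torch.zeros(4, 84, 84, 1), "global_state": torch.zeros(4, 6)}}, [], None
)
print(_logits.shape, _check_model.value_function().shape)  # torch.Size([4, 3]) torch.Size([4])
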
class Wrapper:
    def __init__(self, env):
        self.env = env
        self.possible_agents = env.possible_agents
        self.agents = env.possible_agents
        self.metadata = getattr(env, "metadata", {})
        self.observation_space = lambda agent_id: env.observation_space(agent_id)
        self.action_space = lambda agent_id: env.action_space[agent_id]
        self.reset = env.reset
        self.step = env.step

    def __getattr__(self, name):
        # Delegate any other attribute access to the wrapped env.
        # (Assigning a lambda to self.__getattr__ inside __init__ has no effect,
        # because dunder methods are looked up on the class, not the instance.)
        return getattr(self.env, name)
def env_creator(config):
    raw_env = MultiAgentAirSimEnv(ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2)
    wrapped_env = Wrapper(raw_env)
    return ParallelPettingZooEnv(wrapped_env)
register_env("airsim_multiagent", env_creator)
def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    return "shared_policy"

config = (
    PPOConfig()
    .api_stack(enable_rl_module_and_learner=False, enable_env_runner_and_connector_v2=False)  # Disable new API stack
    .environment("airsim_multiagent")
    .framework("torch")
    .training(
        model={"custom_model": "centralized_critic", "vf_share_layers": False},
        train_batch_size=4000,  # This is the main batch size
        gamma=0.99,
        lambda_=0.95,
        vf_clip_param=10.0,
        clip_param=0.2,
        entropy_coeff=0.01,
        num_sgd_iter=10,  # Number of SGD iterations
        use_gae=True,
    )
    .multi_agent(
        policies={
            "shared_policy": PolicySpec(),
        },
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["shared_policy"],
    )
    .env_runners(num_env_runners=1)
)
ray.init(ignore_reinit_error=True)
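For completeness, training is then launched from this config with a plain build-and-train loop along these lines (sketch; the iteration count and result key are placeholders from the old API stack):

algo = config.build()  # (or config.build_algo() on newer Ray versions)
for i in range(100):
    result = algo.train()
    print(f"iter {i}: episode_reward_mean={result.get('episode_reward_mean')}")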
Environment Initialization
import setup_path
import airsim
import numpy as np
import math
import time
import sys
import cv2
from PIL import Image
import logging
logger = logging.getLogger(__name__)
import gymnasium as gym
from gymnasium import spaces
from gymnasium.spaces import Discrete, MultiDiscrete, Box, Tuple
from gymnasium.utils import EzPickle
from pettingzoo import ParallelEnv
from pettingzoo import AECEnv
from pettingzoo.utils import agent_selector, wrappers
from pettingzoo.utils.env import ObsDict, ActionDict
import random
from collections import defaultdict
import json
class MultiAgentAirSimEnv(ParallelEnv):
    metadata = {"render.modes": ["human"], "name": "airsim-drone-sample-v0"}

    def __init__(self, ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2):
        EzPickle.__init__(self, ip_address, step_length, image_shape, n_agents)
        self.render_mode = "human"
        self.step_length = step_length
        self.image_shape = image_shape
        self.current_step = 0
        self.n_agents = n_agents

        self.drone = airsim.MultirotorClient(ip=ip_address)
        self.image_request = airsim.ImageRequest(
            3, airsim.ImageType.DepthPerspective, True, False
        )

        self.possible_agents = [f"Drone{i}" for i in range(n_agents)]
        self.agents = self.possible_agents[:]

        self.action_space = {agent: spaces.Discrete(3) for agent in self.possible_agents}

        image_space = Box(low=0, high=255, shape=image_shape, dtype=np.uint8)
        global_state_space = Box(low=-np.inf, high=np.inf, shape=(n_agents * 3,), dtype=np.float32)
        observation_space = Tuple([image_space, global_state_space])
        self.observation_spaces = {
            agent: observation_space for agent in self.possible_agents
        }
        self.action_spaces = {
            agent: self.action_space[agent] for agent in self.possible_agents
        }

        self.episodeN = 0
        self._setup_flight()
        self.training_mode = True
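
For reference, this is how the declared per-agent space samples for n_agents=2 (sketch; constructing the env needs a running AirSim instance because __init__ calls _setup_flight):

env = MultiAgentAirSimEnv(ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2)
sample = env.observation_spaces["Drone0"].sample()
print(type(sample))                      # <class 'tuple'>  (Tuple space)
print(sample[0].shape, sample[0].dtype)  # (84, 84, 1) uint8  -> image
print(sample[1].shape, sample[1].dtype)  # (6,) float32       -> global state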
Observation Function
def _get_obs(self):
    observations = {}
    global_state = []

    for agent in self.agents:
        responses = self.drone.simGetImages([self.image_request], vehicle_name=agent)
        image = self.transform_obs(responses)  # Convert AirSim response to NumPy array

        if image.ndim == 3:
            image = np.transpose(image, (2, 1, 0))
            if image.shape[0] == 1:
                image = np.squeeze(image)
                image = np.expand_dims(image, axis=-1)
            elif image.shape[0] == 3:
                image = np.transpose(image, (1, 2, 0))

        agent_obs = {"image": image}

        pos = self.drone.simGetGroundTruthKinematics(vehicle_name=agent).position
        global_state_agent = np.array([pos.x_val, pos.y_val, pos.z_val], dtype=np.float32)

        agent_obs = {
            "image": image.astype(np.float32),
            "global_state": global_state_agent,
        }
        observations[agent] = agent_obs

    return observations
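
And this is the kind of quick check that makes me think the shapes and types don't line up (sketch; again assumes AirSim is running):

env = MultiAgentAirSimEnv(ip_address="127.0.0.1", step_length=1, image_shape=(84, 84, 1), n_agents=2)
obs = env._get_obs()
agent_obs = obs["Drone0"]
print(type(agent_obs), list(agent_obs.keys()))                # dict with 'image' and 'global_state'
print(agent_obs["image"].shape, agent_obs["image"].dtype)     # float32 image (declared space is uint8)
print(agent_obs["global_state"].shape)                        # (3,) per agent
print(env.observation_spaces["Drone0"].contains(agent_obs))   # False (declared space is a Tuple)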