Errors wiring MARL agents with a custom PettingZoo env

How severely does this issue affect your experience of using Ray?
High: It blocks me from completing my task.


HIGH LEVEL DESCRIPTION:
I'm trying to wire together a custom PettingZoo environment and multiple PPO agents. Depending on what I try, I get one of two errors:

AttributeError: 'MultiAgentRLModuleConfig' object has no attribute 'get_catalog'

or

assert self.env and self.module

The latter is an AssertionError raised by that assert statement (the full output is at the bottom of this post).


SETUP:
Custom PettingZoo env:

from copy import copy
import pygame

import numpy as np
from pettingzoo import ParallelEnv
from gymnasium import spaces

class MultiGridWorld(ParallelEnv):

    metadata = {
        "render_modes": ["human", "rgb_array"],
        "render_fps": 4,
        "name": "multi_grid_world_v0"
    }

    def __init__(self, render_mode=None, grid_size=5):
        super().__init__()

        self.possible_agents = ["red", "blue"]
        self.agents = copy(self.possible_agents)

        self.grid_size = grid_size
        self.window_size = 512

        # Action space for each agent (it could differ between agents)
        self.action_spaces = {name: spaces.Discrete(4) for name in self.possible_agents}

        # Observations are dictionaries with each agent's location and the flag's location.
        # Each location is an element of {0, ..., grid_size - 1}^2, encoded as a Box of shape (2,).
        self.observation_spaces = {
            name: spaces.Dict(
                {
                    "red": spaces.Box(0, self.grid_size - 1, shape=(2,), dtype=int),
                    "flag": spaces.Box(0, self.grid_size - 1, shape=(2,), dtype=int),
                    "blue": spaces.Box(0, self.grid_size - 1, shape=(2,), dtype=int),
                }
            )
            for name in self.possible_agents
        }

        """
        The following dictionary maps abstract actions from `self.action_space` to
        the direction we will walk in if that action is taken.
        I.e. 0 corresponds to "right", 1 to "up" etc.
        """
        self._action_to_direction = {
            0: np.array([1,0]),
            1: np.array([0,1]),
            2: np.array([-1,0]),
            3: np.array([0,-1]),
        }

        assert render_mode is None or render_mode in self.metadata["render_modes"]
        self.render_mode = render_mode

        """
        If human-rendering is used, `self.window` will be a reference
        to the window that we draw to. `self.clock` will be a clock that is used
        to ensure that the environment is rendered at the correct framerate in
        human-mode. They will remain `None` until human-mode is used for the
        first time.
        """
        self.window = None
        self.clock = None

        self._red_location = np.random.randint(0, self.grid_size, size=2, dtype=int)
        self._flag_location = np.random.randint(0, self.grid_size, size=2, dtype=int)
        self._blue_location = np.random.randint(0, self.grid_size, size=2, dtype=int)

    def reset(self, seed=None, options=None):

        self.agents = copy(self.possible_agents)

        # Initialize a random starting location for agents
        self._red_location = np.random.randint(0, self.grid_size, size=2, dtype=int)

        # Sample the flag's location randomly until it does not coincide with red's location
        self._flag_location = self._red_location
        while np.array_equal(self._flag_location, self._red_location):
            self._flag_location = np.random.randint(0, self.grid_size, size=2, dtype=int)
        
        # Sample blue's location randomly until it does not coincide with red's or the flag's location
        self._blue_location = self._red_location
        while np.array_equal(self._blue_location, self._red_location) or np.array_equal(self._blue_location, self._flag_location):
            self._blue_location = np.random.randint(0, self.grid_size, size=2, dtype=int)
        
        observations = {
            name:   {
                "red": self._red_location,
                "flag": self._flag_location,
                "blue": self._blue_location
            }
            for name in self.possible_agents
        }

        infos = {a: {} for a in self.agents}

        return observations, infos

    def step(self, actions):
        # Map the action (element of {0,1,2,3}) to the direction we walk in
        red_action = actions["red"] # int
        red_direction = self._action_to_direction[red_action]
        # We use `np.clip` to make sure we don't leave the grid
        self._red_location = np.clip(self._red_location + red_direction, 0, self.grid_size - 1)

        blue_action = actions["blue"] # int
        blue_direction = self._action_to_direction[blue_action]
        # We use `np.clip` to make sure we don't leave the grid
        self._blue_location = np.clip(self._blue_location + blue_direction, 0, self.grid_size - 1)

        # An episode is done iff red gets the flag or blue captures red
        terminations = {a: False for a in self.agents}
        rewards = {a: 0 for a in self.agents}
        truncations = {a: False for a in self.agents}

        if np.array_equal(self._red_location, self._flag_location):
            terminations = {a: True for a in self.agents}
            rewards["red"] = 1
            rewards["blue"] = -1
        
        if np.array_equal(self._red_location, self._blue_location):
            terminations = {a: True for a in self.agents}
            rewards["red"] = -1
            rewards["blue"] = 1
        
        observations = {
            name:   {
                "red": self._red_location,
                "flag": self._flag_location,
                "blue": self._blue_location,
            }
            for name in self.possible_agents
        }

        infos = {a: {} for a in self.agents}

        if any(terminations.values()) or all(truncations.values()):
            self.agents = []
            #self.agents = ["red"]
        
        return observations, rewards, terminations, truncations, infos

    def render(self):
        self._render_frame()

    def observation_space(self, agent):
        return self.observation_spaces[agent]
    
    def action_space(self, agent):
        return self.action_spaces[agent]
    
    def _render_frame(self):
        if self.window is None and self.render_mode == "human":
            pygame.init()
            pygame.display.init()
            self.window = pygame.display.set_mode(
                (self.window_size, self.window_size)
            )
        if self.clock is None and self.render_mode == "human":
            self.clock = pygame.time.Clock()

        canvas = pygame.Surface((self.window_size, self.window_size))
        canvas.fill((255, 255, 255))
        pix_square_size = (
            self.window_size / self.grid_size
        )  # The size of a single grid square in pixels

        # First we draw the flag
        pygame.draw.rect(
            canvas,
            (0, 0, 0),
            pygame.Rect(
                pix_square_size * self._flag_location,
                (pix_square_size, pix_square_size),
            ),
        )
        # Now we draw the red agent
        pygame.draw.circle(
            canvas,
            (255, 0, 0),
            (self._red_location + 0.5) * pix_square_size,
            pix_square_size / 3,
        )
        # Now we draw the blue agent
        pygame.draw.circle(
            canvas,
            (0, 0, 255),
            (self._blue_location + 0.5) * pix_square_size,
            pix_square_size / 3,
        )

        # Finally, add some gridlines
        for x in range(self.grid_size + 1):
            pygame.draw.line(
                canvas,
                0,
                (0, pix_square_size * x),
                (self.window_size, pix_square_size * x),
                width=3,
            )
            pygame.draw.line(
                canvas,
                0,
                (pix_square_size * x, 0),
                (pix_square_size * x, self.window_size),
                width=3,
            )

        if self.render_mode == "human":
            # The following line copies our drawings from `canvas` to the visible window
            self.window.blit(canvas, canvas.get_rect())
            pygame.event.pump()
            pygame.display.update()

            # We need to ensure that human-rendering occurs at the predefined framerate.
            # The following line will automatically add a delay to keep the framerate stable.
            self.clock.tick(self.metadata["render_fps"])
        else:  # rgb_array
            return np.transpose(
                np.array(pygame.surfarray.pixels3d(canvas)), axes=(1, 0, 2)
            )
        
    def close(self):
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()

MAIN:

import os

from poc_gym.multi_grid_world import MultiGridWorld
from stable_baselines3.ppo import MlpPolicy

import supersuit as ss

from pettingzoo.utils import parallel_to_aec

import ray
from ray import tune
from ray.tune.registry import register_env
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo import PPO


from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

from gymnasium import spaces

from ray.rllib.connectors.env_to_module import FlattenObservations
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

def run(train: bool = False, save_file_name: str = None):
    if train:
        ray.init()
        env_name = save_file_name
        env = parallel_to_aec(MultiGridWorld())
        register_env(env_name, lambda config: PettingZooEnv(env))
        ModelCatalog.register_custom_model(env_name, MultiGridWorld)

        policies = {"red", "blue"}

        spec = MultiAgentRLModuleSpec(
            module_specs={
                    "red": SingleAgentRLModuleSpec(
                        observation_space=spaces.utils.flatten_space(env.observation_space("red")),
                        action_space={name: spaces.Discrete(4) for name in policies},
                        model_config_dict={
                            "fcnet_hiddens": [256]
                            },
                    ),
                    "blue": SingleAgentRLModuleSpec(
                        observation_space=spaces.utils.flatten_space(env.observation_space("blue")),
                        action_space={name: spaces.Discrete(4) for name in policies},
                        model_config_dict={"fcnet_hiddens": [256]},
                    ),
                },
        )


        config = (
            PPOConfig()
                .environment(env=env_name, clip_actions=True, env_config={"num_agents": len(policies)})
                # Switch both the new API stack flags to True (both False by default).
                # This enables the use of
                # a) RLModule (replaces ModelV2) and Learner (replaces Policy)
                # b) and automatically picks the correct EnvRunner (single-agent vs multi-agent)
                # and enables ConnectorV2 support.
                .api_stack(
                    enable_rl_module_and_learner=True,
                    enable_env_runner_and_connector_v2=True,
                )

                .framework("torch")

                .learners(
                    num_learners=1,  # <- in most cases, set this value to the number of GPUs
                    num_gpus_per_learner=1,  # <- set this to 1, if you have at least 1 GPU
                )

                .rl_module(
                    model_config_dict={"fcnet_hiddens": [256, 256]},
                    rl_module_spec=spec,
                )

                # Because you are in a multi-agent env, you have to set up the usual multi-agent
                # parameters:
                .multi_agent(
                    policies=policies,
                    # Exact 1:1 mapping from AgentID to ModuleID.
                    policy_mapping_fn=(lambda aid, *args, **kwargs: aid),
                )

                .env_runners(
                    num_env_runners=1,
                    num_envs_per_env_runner=1,
                    rollout_fragment_length="auto",
                    env_to_module_connector=lambda env: FlattenObservations(multi_agent=True),
                )

                .debugging(log_level="ERROR")

                # Note: uses_new_env_runners
                .training(
                    train_batch_size=512,
                    lr=2e-5,
                    gamma=0.99,
                    lambda_=0.9,
                    use_gae=True,
                    clip_param=0.4,
                    grad_clip=None,
                    entropy_coeff=0.1,
                    vf_loss_coeff=0.25,
                    sgd_minibatch_size=64,
                    num_sgd_iter=10,
                    model={"uses_new_env_runners": True}
                )
        )
        
        tune.run(
            "PPO",
            name="PPO",
            stop={"timesteps_total": 500000},
            checkpoint_freq=20,
            config=config.to_dict(),
        )

    else:
        ray.init()
        env_name = save_file_name
        env = parallel_to_aec(MultiGridWorld(render_mode="human"))
        register_env(env_name, lambda config: PettingZooEnv(env))

        # We pick up the policies later when sampling the action
        ppo_agent = PPO.from_checkpoint('/home/gulli/ray_results/PPO/PPO_multi_grid_world_championship_v0_14eb8_00000_0_2024-08-09_18-59-11/checkpoint_000047')

        reward_sum = 0
        frame_list = []
        i = 0
        env.reset()

        while True:

            for agent in env.agent_iter():
                observation, reward, termination, truncation, info = env.last()
                reward_sum += reward

                env.render()

                if termination or truncation:
                    #action = None
                    print(f"{agent}: {reward}")
                    env.reset()
                else:
                    if agent == "red":
                        action = ppo_agent.compute_single_action(observation=observation, policy_id="red")
                    else:
                        action = ppo_agent.compute_single_action(observation=observation, policy_id="blue")

                env.step(action)

            print(reward_sum)
        env.close()

if __name__ == "__main__":
    run(True, "multi_grid_world_championship_v0")

PROBLEM:

Now, if you run it like that, you get this error:

 assert self.env and self.module

Closer examination shows that the module is None, although I would have expected a default module (MultiAgentRLModule?) to be built here.

If I instead add a specific marl_module_class, e.g. PPOTorchRLModule (a sketch of what I mean follows the error below), I get this error:

AttributeError: 'MultiAgentRLModuleConfig' object has no attribute 'get_catalog'
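
For completeness, this is roughly what I mean (reusing the imports and the env from the MAIN script above; note that I also pass a plain Discrete(4) as each agent's action_space here instead of the dict I use in my spec above, since I'm not sure which of the two is expected):

# Same spec as in MAIN, but with PPOTorchRLModule (a single-agent module class)
# passed as marl_module_class. This is the variant that raises the
# "'MultiAgentRLModuleConfig' object has no attribute 'get_catalog'" error for me.
spec = MultiAgentRLModuleSpec(
    marl_module_class=PPOTorchRLModule,
    module_specs={
        agent_id: SingleAgentRLModuleSpec(
            observation_space=spaces.utils.flatten_space(env.observation_space(agent_id)),
            action_space=spaces.Discrete(4),
            model_config_dict={"fcnet_hiddens": [256]},
        )
        for agent_id in ["red", "blue"]
    },
)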

I got this working with the old API stack, but I feel I'm very close to getting it to work with the new one. Does anyone have any ideas? The only other thing I've thought of trying is sketched below.
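
Concretely, something like this (untested; the module_class and catalog_class arguments on each SingleAgentRLModuleSpec are my guess at the intended usage, again reusing the imports from MAIN):

# Untested idea: give every per-agent spec its own module_class and catalog_class,
# so each single-agent config can build a PPOCatalog itself, and leave
# marl_module_class at its default.
spec = MultiAgentRLModuleSpec(
    module_specs={
        agent_id: SingleAgentRLModuleSpec(
            module_class=PPOTorchRLModule,
            catalog_class=PPOCatalog,
            observation_space=spaces.utils.flatten_space(env.observation_space(agent_id)),
            action_space=spaces.Discrete(4),
            model_config_dict={"fcnet_hiddens": [256]},
        )
        for agent_id in ["red", "blue"]
    },
)

Is that the intended way to combine PPOTorchRLModule and PPOCatalog in a MultiAgentRLModuleSpec, or am I missing something else?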

Some additional info: MultiAgentEnvRunner reports this:

(PPO pid=10071)     assert self.env and self.module
(PPO pid=10071) AssertionError
(MultiAgentEnvRunner pid=10133) ENV: <OrderEnforcing<PettingZooEnv<rllib-multi-agent-env-v0>>>, MODULE: None
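
One more thing I'm unsure about: in register_env I wrap a single, already-constructed AEC env, so every env runner would share that one instance. An alternative I'm considering is to construct a fresh env inside the lambda and use RLlib's parallel wrapper directly, skipping the parallel_to_aec conversion; a minimal sketch, assuming ParallelPettingZooEnv is the right wrapper for a ParallelEnv:

from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv

# Build a new MultiGridWorld inside the lambda so each env runner gets its own
# copy, and wrap the ParallelEnv directly instead of converting it to AEC first.
register_env(
    "multi_grid_world_championship_v0",
    lambda config: ParallelPettingZooEnv(MultiGridWorld()),
)

I don't know whether that has anything to do with MODULE being None, though.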