Using Connectors to store, retrieve, and apply an action mask?

1. Severity of the issue: (select one)
None: I’m just curious or want clarification.

2. Environment:

  • Ray version: 2.48.0
  • Python version: 3.11.13
  • OS: Ubuntu
  • Cloud/Infrastructure: N/A

Hello, all. I’ve been working on a project involving action masking, and I thought I would try to implement it using connectors rather than through a custom module (as is done in the examples). The aim here is to allow a bit more modularity when adding action masks, and to better understand how ConnectorV2 is meant to be used.

I thought I had a pretty straightforward solution: strip the action mask from the observation in the env_to_module_connector, put it on the batch with add_batch_item, then retrieve it in the module_to_env_connector and adjust action sampling accordingly. Unfortunately, the batch fed into the latter comes directly from the module, so it is missing the extra information (the action mask) that I appended.

Am I going about this the wrong way? A pair of connectors seems like a natural solution here. Is there a better way to store additional information gathered in the env-to-module pipeline?

My (incomplete) code is below, as a reference:


ActionMaskEnvConnector
from typing import Any, Collection, Dict, List, Optional

import gymnasium as gym

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AgentID, EpisodeType
from ray.util.annotations import PublicAPI
from gymnasium.wrappers.vector import DictInfoToList

ACTION_MASK = "action_mask"
OBSERVATIONS = "observations"

@PublicAPI(stability="alpha")
class ActionMaskEnvConnector(ConnectorV2):
    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space,
        input_action_space,
    ) -> gym.Space:
        # Strip the action mask out of the observation space so that the
        # RLModule only sees the plain observations.
        if isinstance(input_observation_space, DictInfoToList):
            input_observation_space = input_observation_space.observation_space
        if self._multi_agent:
            ret = {}
            for agent_id, obs_space in input_observation_space.spaces.items():
                ret[agent_id] = obs_space[OBSERVATIONS]
            return gym.spaces.Dict(ret)
        else:
            return input_observation_space[OBSERVATIONS]

    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        multi_agent: bool = False,
        agent_ids: Optional[Collection[AgentID]] = None,
        **kwargs,
    ):
        self._input_obs_base_struct = None
        self._multi_agent = multi_agent
        self._agent_ids = agent_ids

        super().__init__(input_observation_space, input_action_space, **kwargs)

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        for sa_episode in self.single_agent_episode_iterator(
            episodes, agents_that_stepped_only=True
        ):
            last_obs = sa_episode.get_observations(-1)

            a_mask = None
            if self._multi_agent:
                if (
                    self._agent_ids is not None
                    and sa_episode.agent_id not in self._agent_ids
                ):
                    new_obs = last_obs  # Agent wasn't part of this step.
                else:
                    new_obs = last_obs[OBSERVATIONS]
                    a_mask = last_obs[ACTION_MASK]
            else:
                new_obs = last_obs[OBSERVATIONS]
                a_mask = last_obs[ACTION_MASK]

            # Write the new (mask-free) observation directly back into the episode.
            sa_episode.set_observations(at_indices=-1, new_data=new_obs)
            if a_mask is not None:
                # Store the action mask on the batch for action postprocessing.
                self.add_batch_item(
                    batch=batch,
                    column=ACTION_MASK,
                    item_to_add=a_mask,
                    single_agent_episode=sa_episode,
                )
            #  We set the Episode's observation space to ours so that we can safely
            #  set the last obs to the new value (without causing a space mismatch
            #  error).
            sa_episode.observation_space = self.observation_space
        print("Env Connector Called")
        print(batch)
        return batch
ActionMaskModuleConnector
from typing import Any, Collection, Dict, List, Optional

import gymnasium as gym

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AgentID, EpisodeType
from ray.util.annotations import PublicAPI
from gymnasium.wrappers.vector import DictInfoToList


@PublicAPI(stability="alpha")
class ActionMaskModuleConnector(ConnectorV2):
    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        print("Module connector called")
        print(batch)
        raise Exception()
        return batch
Test Code
import numpy as np
from gymnasium.spaces import Box, Discrete
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

from ray.rllib.examples.envs.classes.action_mask_env import ActionMaskEnv

config = (
        PPOConfig()
        .environment(
            env=ActionMaskEnv,
            env_config={
                "action_space": Discrete(100),
                "observation_space": Box(-1.0, 1.0, (5,)),
            },
        )
        .env_runners(
            num_env_runners=0,
            num_envs_per_env_runner=1,
            env_to_module_connector=ActionMaskEnvConnector, # Strips the action mask from the input observations
            module_to_env_connector=ActionMaskModuleConnector,
        )
        .training(
            train_batch_size=200,
        )
    )
algo = config.build_algo()
algo.step()

Update: I hadn’t noticed the shared_data argument in the __call__ signature, but it seems to be exactly what I’m after. Would changing this line:

self.add_batch_item(batch=batch, column=ACTION_MASK, item_to_add=a_mask, single_agent_episode=sa_episode)

to this line:

self.add_batch_item(batch=shared_data, column=ACTION_MASK, item_to_add=a_mask, single_agent_episode=sa_episode)

be the intended way to solve my problem?
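
In other words, here's a minimal sketch of the handshake I'm imagining (the MaskWriter/MaskReader classes are hypothetical, and I'm assuming the env runner hands the same shared_data dict to both pipelines within a single sampling step):

from ray.rllib.connectors.connector_v2 import ConnectorV2

ACTION_MASK = "action_mask"


class MaskWriter(ConnectorV2):
    # env-to-module side: stash the mask in shared_data rather than the batch.
    def __call__(self, *, rl_module, batch, episodes, explore=None,
                 shared_data=None, **kwargs):
        for sa_episode in self.single_agent_episode_iterator(
            episodes, agents_that_stepped_only=True
        ):
            mask = sa_episode.get_observations(-1)[ACTION_MASK]
            self.add_batch_item(
                batch=shared_data,  # <- shared_data, not batch
                column=ACTION_MASK,
                item_to_add=mask,
                single_agent_episode=sa_episode,
            )
        return batch


class MaskReader(ConnectorV2):
    # module-to-env side: the masks written above should still be visible here.
    def __call__(self, *, rl_module, batch, episodes, explore=None,
                 shared_data=None, **kwargs):
        masks = (shared_data or {}).get(ACTION_MASK, {})
        # ...use `masks` to adjust the sampled actions/logps in `batch`...
        return batch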

One thing that worries me is that PPO takes a direct forward pass through the module (via self.module.forward_train(batch)) to obtain curr_action_dist. Should I take any special measures to mitigate the discrepancy between the ‘true’ (masked) action distribution and the one the module outputs? I can see cases where that discrepancy would impede training. For instance, if the action probabilities implied by the logits are [0.05, 0.9, 0.05] but action 1 is masked, action 0 was actually sampled with probability 0.5 while the unmasked distribution makes it look like 0.05, and action 1’s logit looks like a viable lever for increasing/decreasing action 0’s probability when it really isn’t.
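
To make that concrete, here is a tiny standalone illustration (plain torch, outside RLlib) of how far apart the two log-probs for the same action can be, depending on whether the mask was applied at sampling time:

import torch
from torch.distributions import Categorical

# Logits whose softmax is roughly [0.05, 0.9, 0.05].
logits = torch.log(torch.tensor([0.05, 0.9, 0.05]))
mask = torch.tensor([1.0, 0.0, 1.0])  # action 1 is invalid

unmasked = Categorical(logits=logits)
masked = Categorical(logits=logits + torch.clamp(torch.log(mask), min=-1e38))

action = torch.tensor(0)
print(unmasked.log_prob(action).exp())  # ~0.05 -> what forward_train's distribution implies
print(masked.log_prob(action).exp())    # ~0.50 -> probability the action was actually sampled with

If I store the masked logp while sampling but PPO recomputes curr_logp from the unmasked logits, the importance ratio ends up mixing the two, which is exactly the discrepancy I’m worried about.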


In any case, here’s my action masking code, implemented using Connectors. It runs and learns when set to play Tic Tac Toe against a heuristic model.

ActionMaskEnvConnector
from typing import Any, Collection, Dict, List, Optional

import gymnasium as gym

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AgentID, EpisodeType
from ray.util.annotations import PublicAPI
from gymnasium.wrappers.vector import DictInfoToList

ACTION_MASK = "action_mask"
OBSERVATIONS = "observations"

@PublicAPI(stability="alpha")
class ActionMaskEnvConnector(ConnectorV2):
    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space,
        input_action_space,
    ) -> gym.Space:
        # Strip the action mask out of the observation space so that the
        # RLModule only sees the plain observations.
        if isinstance(input_observation_space, DictInfoToList):
            input_observation_space = input_observation_space.observation_space
        if self._multi_agent:
            ret = {}
            for agent_id, obs_space in input_observation_space.spaces.items():
                ret[agent_id] = obs_space[OBSERVATIONS]
            return gym.spaces.Dict(ret)
        else:
            return input_observation_space[OBSERVATIONS]

    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        multi_agent: bool = False,
        agent_ids: Optional[Collection[AgentID]] = None,
        **kwargs,
    ):
        self._input_obs_base_struct = None
        self._multi_agent = multi_agent
        self._agent_ids = agent_ids

        super().__init__(input_observation_space, input_action_space, **kwargs)

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        for sa_episode in self.single_agent_episode_iterator(
            episodes, agents_that_stepped_only=True
        ):
            last_obs = sa_episode.get_observations(-1)

            a_mask = None
            if self._multi_agent:
                if (
                    self._agent_ids is not None
                    and sa_episode.agent_id not in self._agent_ids
                ):
                    new_obs = last_obs  # Agent wasn't part of this step.
                else:
                    new_obs = last_obs[OBSERVATIONS]
                    a_mask = last_obs[ACTION_MASK]
            else:
                new_obs = last_obs[OBSERVATIONS]
                a_mask = last_obs[ACTION_MASK]

            # Write the new (mask-free) observation directly back into the episode.
            sa_episode.set_observations(at_indices=-1, new_data=new_obs)
            if a_mask is not None:
                # Stash the action mask in shared_data so the module-to-env
                # connector can use it for action postprocessing.
                self.add_batch_item(
                    batch=shared_data,
                    column=ACTION_MASK,
                    item_to_add=a_mask,
                    single_agent_episode=sa_episode,
                )
            #  We set the Episode's observation space to ours so that we can safely
            #  set the last obs to the new value (without causing a space mismatch
            #  error).
            sa_episode.observation_space = self.observation_space
        return batch
ActionMaskModuleConnector
from typing import Any, Collection, Dict, List, Optional
import torch

import gymnasium as gym

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.core.columns import Columns
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AgentID, EpisodeType
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.util.annotations import PublicAPI
from gymnasium.wrappers.vector import DictInfoToList

ACTION_MASK = "action_mask"

@PublicAPI(stability="alpha")
class ActionMaskModuleConnector(ConnectorV2):
    def masked_sample(self, logits, mask):
        """Samples from the mask-adjusted categorical; returns (action, logp)."""
        logits, mask = torch.as_tensor(logits), torch.as_tensor(mask)
        # log(0) = -inf for invalid actions; clamp to FLOAT_MIN to keep logits finite.
        inf_mask = torch.clamp(torch.log(mask), min=FLOAT_MIN)
        logits_masked = logits + inf_mask
        cat = torch.distributions.Categorical(logits=logits_masked)
        action = cat.sample()
        return action, cat.log_prob(action)

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        if Columns.ACTION_DIST_INPUTS in batch:
            for k in batch[Columns.ACTION_DIST_INPUTS].keys():
                # Re-sample the most recent action from the masked distribution
                # and overwrite the action / logp the module produced.
                logits = batch[Columns.ACTION_DIST_INPUTS][k][-1]
                mask = shared_data[ACTION_MASK][k][-1]
                new_action, new_logp = self.masked_sample(logits, mask)
                batch[Columns.ACTIONS][k][-1] = new_action
                batch[Columns.ACTION_LOGP][k][-1] = new_logp
        return batch
Test Code
import numpy as np
from gymnasium.spaces import Box, Discrete
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec

from ray.rllib.examples.envs.classes.action_mask_env import ActionMaskEnv

config = (
        PPOConfig()
        .environment(
            env=ActionMaskEnv,
            env_config={
                "action_space": Discrete(100),
                "observation_space": Box(-1.0, 1.0, (5,)),
            },
        )
        .env_runners(
            num_env_runners=0,
            num_envs_per_env_runner=1,
            env_to_module_connector=ActionMaskEnvConnector, # Strips the action mask from the input observations
            module_to_env_connector=ActionMaskModuleConnector,
        )
        .training(
            train_batch_size=200,
        )
    )
algo = config.build_algo()
algo.step()