1. Severity of the issue: (select one)
None: I’m just curious or want clarification.
2. Environment:
- Ray version: 2.48.0
- Python version: 3.11.13
- OS: Ubuntu
- Cloud/Infrastructure: N/A
Hello, all. I've been working on a project that involves action masking, and I thought I'd try implementing it with connectors rather than through a custom RLModule (as is done in the examples). The aim is to make adding action masks a bit more modular, and to better understand how ConnectorV2 is meant to be used.
I thought I had a pretty straightforward solution: strip the action mask from the observation in the `env_to_module_connector`, put it on the batch using `add_batch_item`, then retrieve it in the `module_to_env_connector` and adjust action sampling accordingly. Unfortunately, the batch fed into the latter comes directly from the module, and is thus missing the extra information (the action mask) that I appended earlier.
Am I going about this the wrong way? A pair of connectors seems like a natural solution here. Is there a better way to store additional information gathered in the env-to-module pipeline?
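Part of what prompts the question: I noticed the `shared_data` argument in the connector signatures. Is something like the sketch below the intended way to pass side-channel info between the two pipelines? (Hypothetical code; I haven't verified that the same `shared_data` dict actually reaches both pipelines within one sampling step.)

```python
from ray.rllib.connectors.connector_v2 import ConnectorV2

ACTION_MASK = "action_mask"


class MaskToSharedData(ConnectorV2):
    """Env-to-module piece: record each episode's latest mask in shared_data."""

    def __call__(self, *, rl_module, batch, episodes, explore=None,
                 shared_data=None, **kwargs):
        if shared_data is not None:
            masks = shared_data.setdefault(ACTION_MASK, {})
            for sa_episode in self.single_agent_episode_iterator(
                episodes, agents_that_stepped_only=True
            ):
                # Assumes the mask is still part of the raw observation here.
                masks[sa_episode.id_] = sa_episode.get_observations(-1)[ACTION_MASK]
        return batch


class MaskFromSharedData(ConnectorV2):
    """Module-to-env piece: read the masks back out of shared_data."""

    def __call__(self, *, rl_module, batch, episodes, explore=None,
                 shared_data=None, **kwargs):
        masks = (shared_data or {}).get(ACTION_MASK, {})
        # ...use `masks` to adjust the sampled actions / dist inputs...
        return batch
```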
My (incomplete) code is below for reference:
ActionMaskEnvConnector
```python
from typing import Any, Collection, Dict, List, Optional

import gymnasium as gym
from gymnasium.wrappers.vector import DictInfoToList

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import AgentID, EpisodeType
from ray.util.annotations import PublicAPI

ACTION_MASK = "action_mask"
OBSERVATIONS = "observations"


@PublicAPI(stability="alpha")
class ActionMaskEnvConnector(ConnectorV2):
    def __init__(
        self,
        input_observation_space: Optional[gym.Space] = None,
        input_action_space: Optional[gym.Space] = None,
        *,
        multi_agent: bool = False,
        agent_ids: Optional[Collection[AgentID]] = None,
        **kwargs,
    ):
        self._input_obs_base_struct = None
        self._multi_agent = multi_agent
        self._agent_ids = agent_ids
        super().__init__(input_observation_space, input_action_space, **kwargs)

    @override(ConnectorV2)
    def recompute_output_observation_space(
        self,
        input_observation_space,
        input_action_space,
    ) -> gym.Space:
        # Strip the action-mask component from the input observation space.
        if isinstance(input_observation_space, DictInfoToList):
            input_observation_space = input_observation_space.observation_space
        if self._multi_agent:
            ret = {}
            for agent_id, obs_space in input_observation_space.spaces.items():
                ret[agent_id] = obs_space[OBSERVATIONS]
            return gym.spaces.Dict(ret)
        else:
            return input_observation_space[OBSERVATIONS]

    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        for sa_episode in self.single_agent_episode_iterator(
            episodes, agents_that_stepped_only=True
        ):
            last_obs = sa_episode.get_observations(-1)
            a_mask = None  # Stays None if this agent wasn't part of the step.
            if self._multi_agent:
                if (
                    self._agent_ids is not None
                    and sa_episode.agent_id not in self._agent_ids
                ):
                    new_obs = last_obs  # Agent wasn't part of this step.
                else:
                    new_obs = last_obs[OBSERVATIONS]
                    a_mask = last_obs[ACTION_MASK]
            else:
                new_obs = last_obs[OBSERVATIONS]
                a_mask = last_obs[ACTION_MASK]
            # Write the stripped observation directly back into the episode.
            sa_episode.set_observations(at_indices=-1, new_data=new_obs)
            if a_mask is not None:
                # Store the action mask for action postprocessing later on.
                self.add_batch_item(
                    batch=batch,
                    column=ACTION_MASK,
                    item_to_add=a_mask,
                    single_agent_episode=sa_episode,
                )
            # Set the episode's observation space to ours so that we can
            # safely set the last obs to the new value (without causing a
            # space mismatch error).
            sa_episode.observation_space = self.observation_space
        print("Env connector called")
        print(batch)
        return batch
```
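For context, here's roughly what a single raw observation from ActionMaskEnv looks like before the connector strips it (values made up; the env bundles the real observation together with a 0/1 mask over the Discrete actions):

```python
import numpy as np

# Illustrative only -- with action_space=Discrete(100) and
# observation_space=Box(-1.0, 1.0, (5,)) as configured in the test code below.
mask = np.ones(100, dtype=np.float32)
mask[[3, 17, 42]] = 0.0  # e.g. these actions are currently invalid

raw_obs = {
    "action_mask": mask,                                  # stripped by the connector
    "observations": np.random.uniform(-1.0, 1.0, (5,)),  # what the module sees
}
```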
ActionMaskModuleConnector
```python
from typing import Any, Dict, List, Optional

from ray.rllib.connectors.connector_v2 import ConnectorV2
from ray.rllib.core.rl_module.rl_module import RLModule
from ray.rllib.utils.annotations import override
from ray.rllib.utils.typing import EpisodeType
from ray.util.annotations import PublicAPI


@PublicAPI(stability="alpha")
class ActionMaskModuleConnector(ConnectorV2):
    @override(ConnectorV2)
    def __call__(
        self,
        *,
        rl_module: RLModule,
        batch: Dict[str, Any],
        episodes: List[EpisodeType],
        explore: Optional[bool] = None,
        shared_data: Optional[dict] = None,
        **kwargs,
    ) -> Any:
        print("Module connector called")
        print(batch)
        # Deliberately stop here to inspect the incoming batch -- this is
        # where ACTION_MASK turns out to be missing.
        raise Exception()
        return batch
```
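And this is roughly what I intended the module-to-env connector's body to do once the mask is available, i.e. the "adjust action sampling" step. Sketch only: `Columns.ACTION_DIST_INPUTS` is the logits column, and `ACTION_MASK` is the column I tried to add above but which never arrives.

```python
import numpy as np
from ray.rllib.core.columns import Columns

ACTION_MASK = "action_mask"


def apply_action_mask(batch: dict) -> dict:
    # Overwrite logits of invalid actions with a large negative value so
    # their sampling probability becomes ~0. This would need to run before
    # the default connector piece that samples the actions.
    logits = batch[Columns.ACTION_DIST_INPUTS]
    mask = batch[ACTION_MASK]  # <- missing from the incoming batch
    batch[Columns.ACTION_DIST_INPUTS] = np.where(mask == 1.0, logits, -1e8)
    return batch
```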
Test Code
```python
from gymnasium.spaces import Box, Discrete

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.envs.classes.action_mask_env import ActionMaskEnv

config = (
    PPOConfig()
    .environment(
        env=ActionMaskEnv,
        env_config={
            "action_space": Discrete(100),
            "observation_space": Box(-1.0, 1.0, (5,)),
        },
    )
    .env_runners(
        num_env_runners=0,
        num_envs_per_env_runner=1,
        # Strips the action mask from the incoming observations.
        env_to_module_connector=ActionMaskEnvConnector,
        module_to_env_connector=ActionMaskModuleConnector,
    )
    .training(
        train_batch_size=200,
    )
)

algo = config.build_algo()
algo.step()
```
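For what it's worth: when I run this, the print in ActionMaskEnvConnector shows the `action_mask` column sitting in the batch as expected, but the batch printed by ActionMaskModuleConnector contains only the module's own outputs; the `action_mask` column is gone, which is the problem described above.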