- High: It blocks me from completing my task.
I am trying to incorporate action masking into the GTrXLNet example and have built the model and script below based on the Ray examples (attention net, action masking) and these two forum posts (1, 2).
My intent is to build a simple FC model that incorporates action masking. Then, to use GTrXLNet, I would set use_attention = True when I configure the algorithm. If this is the wrong design approach, please let me know the correct/preferred way to go about it.
My model runs when attention is disabled. When attention is enabled, it throws a shape error that seems to have to do with AttentionWrapper expecting an observation that includes both the real observation (obs["observations"]) and the action mask (obs["action_mask"]), but instead being given just the real observation part. The AttentionWrapper expects a [..., 4]-sized tensor but is given a [32, 2]-sized tensor. The last dimensions should agree (according to the docstring in _unpack_obs) but do not. The 32 comes from GTrXLNet's default head_dim (I'm pretty sure); the 2 is the size of the real "observations", and the 4 is the size of "observations" + "action_mask".
I think the cause of the error is a mismatch between how I've configured the observation space in my main model vs. the internal FCNet vs. what AttentionWrapper/GTrXLNet expects.
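For reference, the 4 is just the flattened size of the full Dict observation space, while the 2 is the flattened size of the "observations" entry alone. A quick standalone check with gymnasium (nothing RLlib-specific here) shows where those numbers come from:
from gymnasium.spaces import Box, Dict
from gymnasium.spaces.utils import flatdim
from numpy import int64

space = Dict(
    {
        "action_mask": Box(0, 1, (2,), int64),
        "observations": Box(0, 1, (2,), int64),
    }
)
print(flatdim(space))  # 4 -> the [..., 4] that _unpack_obs expects
print(flatdim(space.spaces["observations"]))  # 2 -> the width of the tensor it actually gets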
In the code below, note that the environment observation space is
env.observation_space=Dict('action_mask': Box(0, 1, (2,), int64), 'observations': Box(0, 1, (2,), int64))
Custom env:
"""Action Mask Repeat After Me Env to use in testing."""
# %% Imports
# Third Party Imports
from gymnasium import Env
from gymnasium.spaces import Dict
from gymnasium.spaces.utils import flatten, flatten_space
from numpy import ones
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
class MaskRepeatAfterMe(Env):
    """RepeatAfterMeEnv with action masking.

    There are three options for mask_config:
        "viable_random": The action mask is a random sample from the action space,
            with the first action always available.
        "full_random": The action mask is a random sample from the action space.
        "off": All actions are available (the action mask is an array of ones).
    """

    def __init__(self, config=None):
        """Instantiate MaskRepeatAfterMe."""
        self.internal_env = RepeatAfterMeEnv()
        self.observation_space = Dict(
            {
                "observations": flatten_space(self.internal_env.observation_space),
                "action_mask": flatten_space(self.internal_env.action_space),
            }
        )
        self.action_space = self.internal_env.action_space
        if config is None:
            config = {}
        self.mask_config = config.get("mask_config", "viable_random")

    def reset(self, *, seed=None, options=None):
        """Reset env."""
        obs, info = self.internal_env.reset()
        new_obs = self._wrapObs(obs)
        self.last_obs = new_obs
        return new_obs, info

    def step(self, action):
        """Step env."""
        trunc = self._checkMaskViolation(action)
        obs, reward, done, _, info = self.internal_env.step(action)
        new_obs = self._wrapObs(obs)
        self.last_obs = new_obs
        return new_obs, reward, done, trunc, info

    def _wrapObs(self, unwrapped_obs):
        """Flatten the internal observation and attach an action mask."""
        if self.mask_config == "viable_random":
            mask = self.observation_space.spaces["action_mask"].sample()
            mask[0] = 1
        elif self.mask_config == "full_random":
            mask = self.observation_space.spaces["action_mask"].sample()
        elif self.mask_config == "off":
            mask = ones(self.observation_space.spaces["action_mask"].shape, dtype=int)
        wrapped_obs = {
            "observations": flatten(self.internal_env.observation_space, unwrapped_obs),
            "action_mask": mask,
        }
        return wrapped_obs

    def _checkMaskViolation(self, action):
        """Return True (truncate) if the action violates the last action mask."""
        flat_action = flatten(self.action_space, action)
        diff = self.last_obs["action_mask"] - flat_action
        if any(i < 0 for i in diff):
            truncate = True
            print("mask violation")
        else:
            truncate = False
        return truncate
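A quick sanity check of the env on its own (assuming the class above is defined as-is in the same session):
env = MaskRepeatAfterMe({"mask_config": "off"})
obs, info = env.reset()
print(obs)  # e.g. {'observations': array([...]), 'action_mask': array([1, 1])}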
Model and test script:
# %% Imports
# Standard Library Imports
import inspect
import os
from typing import Dict, List, Optional, Tuple, Union
# Third Party Imports
import gymnasium as gym
import ray
import ray.rllib.algorithms.ppo as ppo
import torch.nn as nn
from gymnasium.spaces.utils import flatten
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelConfigDict, ModelV2
from ray.rllib.models.torch.attention_net import AttentionWrapper
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.tune import registry
from torch import TensorType, reshape
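# maskLogits helper used in CustomAttentionWrapper.forward() below. Assumed
# implementation here, following the standard RLlib action-masking pattern:
# add a large negative value to the logits of unavailable actions.
from ray.rllib.utils.torch_utils import FLOAT_MIN
from torch import clamp, log


def maskLogits(logits: TensorType, mask: TensorType) -> TensorType:
    """Add FLOAT_MIN to logits wherever the action mask is 0."""
    inf_mask = clamp(log(mask), min=FLOAT_MIN)
    return logits + inf_mask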
class CustomAttentionWrapper(TorchModelV2, nn.Module):
    """Action-masking FC model intended to be wrapped by RLlib's AttentionWrapper."""

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, gym.spaces.Dict)
        assert "action_mask" in orig_space.spaces
        assert "observations" in orig_space.spaces
        self.wrapped_obs_space = orig_space.spaces["observations"]

        nn.Module.__init__(self)
        super().__init__(
            obs_space=obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name,
        )

        self.internal_model = TorchFC(
            obs_space=self.wrapped_obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name + "_internal",
        )
        self._value_out = None

    def forward(
        self,
        input_dict: dict[str, TensorType],
        state: Optional[list[TensorType]],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, list[TensorType]]:
        # Remove the action mask from 'obs', 'new_obs', and 'obs_flat'.
        obs_only_dict = self.removeActionMask(input_dict)
        # Pass only the real observations (mask removed) to the internal model
        # and get the logits and new state.
        logits, new_state = self.internal_model(input_dict=obs_only_dict)
        # Mask the logits.
        action_mask = input_dict["obs"]["action_mask"]
        masked_logits = maskLogits(logits=logits, mask=action_mask)
        # Return the masked logits and the new state.
        return masked_logits, new_state

    def removeActionMask(
        self, input_dict: dict[str, TensorType]
    ) -> dict[str, TensorType]:
        """Remove the action mask from the input dict."""
        # Watch out for input_dict being a SampleBatch.
        modified_input_dict = input_dict.copy()
        modified_input_dict["obs"] = input_dict["obs"]["observations"]
        modified_input_dict["obs_flat"] = flatten(
            self.wrapped_obs_space, modified_input_dict["obs"]
        )
        if "new_obs" in modified_input_dict:
            # 'new_obs' is only present in the input dict when using the
            # attention wrapper.
            modified_input_dict["new_obs"] = modified_input_dict["new_obs"][
                :, : modified_input_dict["obs"].shape[1]
            ]
        return modified_input_dict

    @override(ModelV2)
    def get_initial_state(self) -> list[TensorType]:
        return self.internal_model.get_initial_state()

    @override(ModelV2)
    def value_function(self) -> TensorType:
        return self.internal_model.value_function()
if __name__ == "__main__":
    env = MaskRepeatAfterMe()
    print(f"{env.observation_space=}")
    print(f"{env.action_space=}")
    # env.observation_space=Dict('action_mask': Box(0, 1, (2,), int64), 'observations': Box(0, 1, (2,), int64))
    # env.action_space=Discrete(2)

    ray.init(local_mode=True)

    # Register custom environment and model.
    registry.register_env("MaskRepeatAfterMe", MaskRepeatAfterMe)
    ModelCatalog.register_custom_model("CustomAttentionWrapper", CustomAttentionWrapper)

    # Make config.
    config = (
        ppo.PPOConfig()
        .environment(
            "MaskRepeatAfterMe",
            env_config={"mask_config": "off"},
        )
        .training(
            gamma=0.99,
            entropy_coeff=0.001,
            num_sgd_iter=10,
            vf_loss_coeff=1e-5,
            model={
                "custom_model": "CustomAttentionWrapper",
                "fcnet_hiddens": [32, 2],  # last layer must be size of action space
                "use_attention": True,
            },
        )
        .framework("torch")
        .rollouts(num_envs_per_worker=20)
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", 0)))
    )

    # %% Build and train
    algo = config.build()
    algo.train()
Error:
2023-12-08 11:55:25,254 ERROR actor.py:970 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=20734, ip=172.19.84.85, actor_id=c74812efe1010add64e177ff01000000, repr=<ray.rllib.evaluation.rollout_worker._modify_class.<locals>.Class object at 0x7f80739847c0>)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/evaluation/rollout_worker.py", line 738, in __init__
self._update_policy_map(policy_dict=self.policy_dict)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1985, in _update_policy_map
self._build_policy_map(
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/evaluation/rollout_worker.py", line 2097, in _build_policy_map
new_policy = create_policy_for_framework(
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/utils/policy.py", line 142, in create_policy_for_framework
return policy_class(observation_space, action_space, merged_config)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/ppo_torch_policy.py", line 64, in __init__
self._initialize_loss_from_dummy_batch()
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/policy.py", line 1408, in _initialize_loss_from_dummy_batch
actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 526, in compute_actions_from_input_dict
return self._compute_action_helper(
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper
return func(self, *a, **k)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/torch_policy_v2.py", line 1159, in _compute_action_helper
dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/modelv2.py", line 259, in __call__
res = self.forward(restored, state or [], seq_lens)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/torch/attention_net.py", line 444, in forward
self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/modelv2.py", line 247, in __call__
restored["obs"] = restore_original_dimensions(
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/modelv2.py", line 414, in restore_original_dimensions
return _unpack_obs(obs, original_space, tensorlib=tensorlib)
File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/modelv2.py", line 448, in _unpack_obs
raise ValueError(
ValueError: Expected flattened obs shape of [..., 4], got torch.Size([32, 2])