Masked GTrXLNet

  • High: It blocks me from completing my task.

I am trying to incorporate action masking into the GTrXLNet example, and have built the below model and script based on the Ray examples (attention net, action masking) and these two forum posts (1, 2).

My intent is to build a simple FC model that incorporates action masking. Then, to use GTrXLNet, I would set use_attention = True when I configure the algorithm. If this is the wrong design approach, please let me know what the correct/preferred way to go about this is.

My model runs when attention is disabled. When attention is enabled, it throws a shape error that seems to have to do with AttentionWrapper expecting an observation that includes both the real observation (obs["observations"]) and the action mask (obs["action_mask"]), but instead being given only the real observation part. The AttentionWrapper expects a […, 4]-sized tensor but is given a [32, 2]-sized tensor. The last dimensions should agree (according to the docstring in _unpack_obs) but do not. The 32 comes from GTrXLNet’s default head_dim (I’m pretty sure). The 2 is the size of the real “observations”, and the 4 is the size of “observations” + “action_mask”.

I think the cause of the error is in a mismatch between how I’ve configured the observation space in my main model vs the internal FCNet vs what AttentionWrapper / GTrXLNet expects.

In the code below, note that the environment observation space is
env.observation_space=Dict('action_mask': Box(0, 1, (2,), int64), 'observations': Box(0, 1, (2,), int64))

Custom env:

"""Action Mask Repeat After Me Env to use in testing."""
# %% Imports
# Third Party Imports
from gymnasium import Env
from gymnasium.spaces import Dict
from gymnasium.spaces.utils import flatten, flatten_space
from numpy import ones
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv

class MaskRepeatAfterMe(Env):
    """RepeatAfterMeEnv with action masking.

    Observations are a Dict with two entries:
        "observations": the wrapped env's observation, flattened with
            gymnasium's ``flatten``.
        "action_mask": an array shaped like the flattened action space; an
            entry of 0 marks the corresponding action as unavailable.

    There are three options for mask_config:
        "viable_random": The action mask is a random sample from the action space,
            with the first action always available.
        "full_random": The action mask is a random sample from the action space.
        "off": All actions are available (the action mask is an array of ones).

    Taking a masked action still steps the wrapped env, but prints
    "mask violation" and reports the step as truncated.
    """

    _MASK_CONFIGS = ("viable_random", "full_random", "off")

    def __init__(self, config=None):
        """Instantiate MaskRepeatAfterMe.

        Args:
            config: Optional env config mapping. Only "mask_config" is read
                (default "viable_random").

        Raises:
            ValueError: If "mask_config" is not one of the supported options.
        """
        # Local import: a module-level ``from typing import Dict`` elsewhere in
        # the file would otherwise shadow gymnasium's Dict space.
        from gymnasium.spaces import Dict as DictSpace

        self.internal_env = RepeatAfterMeEnv()
        self.observation_space = DictSpace(
            {
                "observations": flatten_space(self.internal_env.observation_space),
                "action_mask": flatten_space(self.internal_env.action_space),
            }
        )
        self.action_space = self.internal_env.action_space
        if config is None:
            config = {}
        self.mask_config = config.get("mask_config", "viable_random")
        # Fail fast instead of hitting an unbound `mask` on the first reset.
        if self.mask_config not in self._MASK_CONFIGS:
            raise ValueError(
                f"Unknown mask_config {self.mask_config!r}; "
                f"expected one of {self._MASK_CONFIGS}."
            )

    def reset(self, *, seed=None, options=None):
        """Reset env; `seed` also seeds the space used to sample masks."""
        super().reset(seed=seed)
        if seed is not None:
            # Masks are drawn from observation_space, so seed it for reproducibility.
            self.observation_space.seed(seed)
        obs, info = self.internal_env.reset(seed=seed, options=options)
        new_obs = self._wrapObs(obs)
        self.last_obs = new_obs
        return new_obs, info

    def step(self, action):
        """Step env; a masked action marks the step as truncated."""
        mask_violated = self._checkMaskViolation(action)
        obs, reward, terminated, truncated, info = self.internal_env.step(action)
        new_obs = self._wrapObs(obs)
        self.last_obs = new_obs
        # Keep the wrapped env's own truncation signal as well.
        return new_obs, reward, terminated, mask_violated or truncated, info

    def _wrapObs(self, unwrapped_obs):
        """Build the Dict observation (flattened obs + fresh action mask)."""
        mask_space = self.observation_space.spaces["action_mask"]
        if self.mask_config == "viable_random":
            mask = mask_space.sample()
            mask[0] = 1  # guarantee at least one legal action
        elif self.mask_config == "full_random":
            mask = mask_space.sample()
        elif self.mask_config == "off":
            mask = ones(mask_space.shape, dtype=int)
        else:
            raise ValueError(f"Unknown mask_config {self.mask_config!r}.")

        return {
            "observations": flatten(self.internal_env.observation_space, unwrapped_obs),
            "action_mask": mask,
        }

    def _checkMaskViolation(self, action):
        """Return True (and print a warning) if `action` is masked out."""
        flat_action = flatten(self.action_space, action)
        # A negative entry means the action was taken where the mask is 0.
        diff = self.last_obs["action_mask"] - flat_action
        truncate = any(d < 0 for d in diff)
        if truncate:
            print("mask violation")
        return truncate

Model and test script

# %% Imports
# Standard Library Imports
import inspect
import os
from typing import Dict, List, Optional, Tuple, Union

# Third Party Imports
import gymnasium as gym
import ray
import ray.rllib.algorithms.ppo as ppo
import torch.nn as nn
from gymnasium.spaces.utils import flatten
from ray.rllib.examples.env.repeat_after_me_env import RepeatAfterMeEnv
from ray.rllib.models import ModelCatalog
from ray.rllib.models.modelv2 import ModelConfigDict, ModelV2
from ray.rllib.models.torch.attention_net import AttentionWrapper
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.annotations import override
from ray.tune import registry
from torch import TensorType, reshape

class CustomAttentionWrapper(TorchModelV2, nn.Module):
    """Fully connected model whose logits are masked with obs["action_mask"].

    Expects a Dict observation space with an "observations" entry (fed to an
    internal TorchFC) and an "action_mask" entry (applied to the logits via
    ``maskLogits``, defined elsewhere).

    Used directly (``use_attention=False``), ``num_outputs`` is an int and this
    model returns masked logits.

    NOTE(review): with ``use_attention=True``, RLlib builds this model with
    ``num_outputs=None`` and wraps it in ``AttentionWrapper`` (see RLlib's
    attention_net.py; verify against your installed version). In that case:
      * ``self.num_outputs`` must hold the feature size. If it stays ``None``,
        the GTrXL appears to be built on the original Dict obs space, which
        matches the traceback "Expected flattened obs shape of [..., 4], got
        torch.Size([32, 2])".
      * This model only produces features for the GTrXL, not logits, so the
        action mask is NOT applied here. Apply it to AttentionWrapper's
        output instead, e.g. in a subclass of AttentionWrapper that overrides
        ``forward``.
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, gym.spaces.Dict)
        assert "action_mask" in orig_space.spaces
        assert "observations" in orig_space.spaces

        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        self.wrapped_obs_space = orig_space.spaces["observations"]

        # Column range of "observations" inside a flattened Dict obs (used for
        # "new_obs"). Assumes Dict entries are flattened in orig_space.spaces
        # order -- TODO confirm against RLlib's Dict preprocessor.
        offset = 0
        for key, space in orig_space.spaces.items():
            if key == "observations":
                break
            offset += gym.spaces.utils.flatdim(space)
        obs_dim = gym.spaces.utils.flatdim(self.wrapped_obs_space)
        self._obs_slice = slice(offset, offset + obs_dim)

        self.internal_model = TorchFC(
            obs_space=self.wrapped_obs_space,
            action_space=action_space,
            num_outputs=num_outputs,
            model_config=model_config,
            name=name + "_internal",
        )

        # Expose the internal model's output size. When num_outputs is None
        # (attention-wrapped), this is the feature size the GTrXL consumes.
        self.num_outputs = self.internal_model.num_outputs
        # Only the logits of an un-wrapped model may be masked.
        self._outputs_logits = num_outputs is not None

        self._value_out = None

    def forward(
        self,
        input_dict: dict[str, TensorType],
        state: Optional[list[TensorType]],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, list[TensorType]]:
        """Run the internal FC on obs["observations"] and mask its logits.

        Returns the masked logits, or the unmasked features when this model was
        built with ``num_outputs=None``. Also returns the internal model's state.
        """
        # Remove action mask from: 'obs', 'new_obs', and 'obs_flat'.
        obs_only_dict = self.removeActionMask(input_dict)

        output, new_state = self.internal_model(input_dict=obs_only_dict)

        if not self._outputs_logits:
            # Features for a wrapper (e.g. GTrXL), not logits: masking them
            # would corrupt the features. Mask the wrapper's logits instead.
            return output, new_state

        action_mask = input_dict["obs"]["action_mask"]
        masked_logits = maskLogits(logits=output, mask=action_mask)
        return masked_logits, new_state

    def removeActionMask(
        self, input_dict: dict[str, TensorType]
    ) -> dict[str, TensorType]:
        """Return a copy of input_dict restricted to the "observations" part.

        'obs' becomes obs["observations"], 'obs_flat' is recomputed from it, and
        a flattened 'new_obs' (if present) is narrowed to the observation columns.
        """
        # Copy first so the caller's dict / SampleBatch is not modified.
        modified_input_dict = input_dict.copy()
        obs = input_dict["obs"]["observations"]
        modified_input_dict["obs"] = obs
        # Flatten with torch to keep the batch dimension, tensor type and
        # device. gymnasium's `flatten` would return a 1-D NumPy array.
        modified_input_dict["obs_flat"] = reshape(obs, (obs.shape[0], -1))
        if "new_obs" in modified_input_dict:
            modified_input_dict["new_obs"] = modified_input_dict["new_obs"][
                :, self._obs_slice
            ]

        return modified_input_dict

    def get_initial_state(self) -> list[TensorType]:
        """Delegate to the internal model."""
        return self.internal_model.get_initial_state()

    def value_function(self) -> TensorType:
        """Delegate to the internal model."""
        return self.internal_model.value_function()

if __name__ == "__main__":
    # Sanity check: construct the env once so space errors surface early.
    env = MaskRepeatAfterMe()
    # env.observation_space=Dict('action_mask': Box(0, 1, (2,), int64), 'observations': Box(0, 1, (2,), int64))
    # env.action_space=Discrete(2)

    # Register the custom environment and model so RLlib can find them by name.
    registry.register_env("MaskRepeatAfterMe", MaskRepeatAfterMe)
    ModelCatalog.register_custom_model("CustomAttentionWrapper", CustomAttentionWrapper)

    # Make config
    config = (
        ppo.PPOConfig()
        .environment(env="MaskRepeatAfterMe", env_config={"mask_config": "off"})
        .framework("torch")  # CustomAttentionWrapper is a TorchModelV2.
        .training(
            model={
                "custom_model": "CustomAttentionWrapper",
                # Hidden sizes of the internal TorchFC. The logits layer is
                # added on top of these (by TorchFC, or by AttentionWrapper
                # when use_attention=True). So the last size does not have to
                # match the action space.
                "fcnet_hiddens": [32, 2],
                "use_attention": True,
            }
        )
        .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", 0)))
    )

    # %% Build and train
    algo = config.build()
    print(algo.train())
    algo.stop()


2023-12-08 11:55:25,254 ERROR -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=20734, ip=, actor_id=c74812efe1010add64e177ff01000000, repr=<ray.rllib.evaluation.rollout_worker._modify_class.<locals>.Class object at 0x7f80739847c0>)
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/evaluation/", line 738, in __init__
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/evaluation/", line 1985, in _update_policy_map
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/evaluation/", line 2097, in _build_policy_map
    new_policy = create_policy_for_framework(
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/utils/", line 142, in create_policy_for_framework
    return policy_class(observation_space, action_space, merged_config)
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/algorithms/ppo/", line 64, in __init__
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/", line 1408, in _initialize_loss_from_dummy_batch
    actions, state_outs, extra_outs = self.compute_actions_from_input_dict(
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/", line 526, in compute_actions_from_input_dict
    return self._compute_action_helper(
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/utils/", line 24, in wrapper
    return func(self, *a, **k)
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/", line 1159, in _compute_action_helper
    dist_inputs, state_out = self.model(input_dict, state_batches, seq_lens)
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/", line 259, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/torch/", line 444, in forward
    self._features, memory_outs = self.gtrxl(input_dict, state, seq_lens)
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/", line 247, in __call__
    restored["obs"] = restore_original_dimensions(
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/", line 414, in restore_original_dimensions
    return _unpack_obs(obs, original_space, tensorlib=tensorlib)
  File "/home/usr/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/", line 448, in _unpack_obs
    raise ValueError(
ValueError: Expected flattened obs shape of [..., 4], got torch.Size([32, 2])