Input to TorchModelV2 forward method is inconsistent

I want to implement action masking for a simple environment with a dictionary observation space using the ApexDQN algorithm.

The input_dict given to the forward method of my custom model sometimes is a SampleBatch and sometimes just a dictionary (with the exact same info the sample batch would have).

Notice I have some (suboptimal) code that iterates over the rows of a sample batch and flattens them. This code does work for PPO, where the input_dict is always a SampleBatch.

What would be the best way to implement action masking for an ApexDQN model?

import numpy as np
import ray

from ray.rllib.algorithms.apex_dqn.apex_dqn import ApexDQNConfig
from ray.tune.registry import register_env
import gymnasium
from gymnasium.spaces import Box, Dict, Discrete

from gymnasium.spaces.utils import flatten_space, flatten
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

torch, nn = try_import_torch()

# copy pasted from rllib/examples/models/
class TorchActionMaskModel(TorchModelV2, nn.Module):
    """PyTorch version of above ActionMaskingModel."""

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name, **kwargs):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert (isinstance(orig_space, Dict)
                and "action_mask" in orig_space.spaces
                and "actual_obs" in orig_space.spaces)

        self.orig_state_space = orig_space["actual_obs"]
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name, **kwargs)
        # nn.Module must be initialised before sub-modules are assigned.
        nn.Module.__init__(self)

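        # The internal model only sees the flattened actual observation.
        self.internal_model = TorchFC(
            flatten_space(self.orig_state_space),
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )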

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        data = []
        for row in input_dict.rows():  # this code only works when input_dict is a SampleBatch
            flattened_sample = flatten(self.orig_state_space,
                                       row["obs"]["actual_obs"])
            data.append(flattened_sample)

        obs = torch.tensor(data)
        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": obs})
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        return logits + inf_mask, state

    def value_function(self):
        return self.internal_model.value_function()

class MyEnv(gymnasium.Env):

    metadata = {"render_modes": ["human"]}

    def __init__(self):
        super().__init__()

        self.actions = 4

        self.action_space = Discrete(self.actions)
        self.observation_space = Dict({
            "action_mask": Box(0, 1, shape=(self.actions, )),
            "actual_obs": Dict({
                "obs1": Box(low=-np.inf, high=np.inf, shape=(10, 10), dtype=np.float32),
                "obs2": Box(low=-np.inf, high=np.inf, shape=(10, 10), dtype=np.float32),
            }),
        })

    def reset(self, *, seed=None, options=None):
        return self._make_obs(), {}

    def step(self, action):
        return self._make_obs(), 0, False, False, {}

    def _make_obs(self):
        return {
            "action_mask": np.array([1.0] * self.actions, dtype=np.float32),
            "actual_obs": {
                "obs1": np.zeros((10, 10), dtype=np.float32),
                "obs2": np.zeros((10, 10), dtype=np.float32),
            },
        }

def main():

    select_env = "env-v1"
    register_env(select_env, lambda config: MyEnv())

    config = (
        ApexDQNConfig()
        .framework("torch")
        .environment(select_env)
        .training(
            model={
                "custom_model": TorchActionMaskModel,
                "no_final_linear": False,
            },
            train_batch_size=32,
        )
    )
    algo = config.build()
    for _ in range(5):
        algo.train()

if __name__ == "__main__":
    main()

Hello @user777888,

This post may be of some use for you.

The issue I’m facing only comes up for dictionary observation spaces, so I’m not sure how the approach above would help. The docs already provide an action masking example for DQN, but it only works for simple state spaces (i.e. not a dict).

Action masks are passed as if they were part of the observations, but the agent should not use them as part of its observation. Hence, we have to extract the actual observations and pass those to the internal_model in the forward function. The actual observations are a dictionary that needs to be flattened manually (normally you would use obs_flat, but that includes the action mask). However, the flattening as implemented above is only possible when the input_dict is a SampleBatch, and for some reason it sometimes isn’t. In the case of PPO, it is always a SampleBatch and the code works.
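To make the idea concrete, here is a minimal sketch of flattening the actual observations batch-wise and then applying the mask, without relying on `rows()`. It uses NumPy only for illustration (in the model the equivalent torch calls would be `Tensor.reshape` and `torch.cat`). The helper names `flatten_batched_obs` and `mask_logits` are made up for this example, the local `FLOAT_MIN` only mimics the RLlib constant, and the sorted key order is an assumption about how the Dict space orders its sub-spaces.

```python
import numpy as np

# Assumption: mimics the idea of RLlib's FLOAT_MIN (most negative float32).
FLOAT_MIN = float(np.finfo(np.float32).min)


def flatten_batched_obs(actual_obs):
    """Flatten a dict of batched arrays [B, ...] into one [B, N] array."""
    parts = []
    # Assumption: sorted key order matches the order of the flattened space.
    for key in sorted(actual_obs):
        value = np.asarray(actual_obs[key], dtype=np.float32)
        parts.append(value.reshape(value.shape[0], -1))
    return np.concatenate(parts, axis=1)


def mask_logits(logits, action_mask):
    """Push logits of disallowed actions (mask == 0) to a huge negative value."""
    with np.errstate(divide="ignore"):  # log(0) = -inf is expected here
        inf_mask = np.maximum(np.log(action_mask), FLOAT_MIN)
    return logits + inf_mask


# A batch of 3 observations with the same layout as the environment above.
obs = {
    "action_mask": np.array([[1, 0, 1, 1]] * 3, dtype=np.float32),
    "actual_obs": {
        "obs1": np.zeros((3, 10, 10), dtype=np.float32),
        "obs2": np.ones((3, 10, 10), dtype=np.float32),
    },
}
flat = flatten_batched_obs(obs["actual_obs"])  # the mask is not included
masked = mask_logits(np.zeros((3, 4), dtype=np.float32), obs["action_mask"])
```

Because this only indexes into `obs` by key, it does not matter whether the container is a SampleBatch or a plain dictionary.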

Hi @user777888,

I am not sure why it is sometimes one and sometimes the other. If I had to venture a guess, I would suspect it depends on whether the observation comes from the rollout sample phase or from the replay buffer in the training phase.

This is how the observations are flattened during collection. What I would probably do is report an issue and, as a temporary workaround, add something like this in forward.

if isinstance(..., SampleBatch):
    input = ...
elif isinstance(..., dict):
    input = tree.flatten(...)
else:
    raise ValueError(f"Unexpected input type: {type(...)}")
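A runnable sketch of that branching, under some assumptions: `BatchLike` is a hypothetical stand-in for SampleBatch so the example runs without Ray, and `extract_actual_obs` is a made-up helper name. As far as I know SampleBatch subclasses dict, so the sketch checks the more specific type first.

```python
class BatchLike(dict):
    """Hypothetical stand-in for a dict subclass such as SampleBatch."""


def extract_actual_obs(input_dict):
    """Return (kind, actual_obs) for either input type, or raise."""
    # Check the subclass first: a BatchLike would also pass the dict check.
    if isinstance(input_dict, BatchLike):
        kind = "batch"
    elif isinstance(input_dict, dict):
        kind = "dict"
    else:
        raise ValueError(f"Unexpected input type: {type(input_dict)}")
    # Both containers support the same key access.
    return kind, input_dict["obs"]["actual_obs"]


plain = {"obs": {"action_mask": [1, 1], "actual_obs": {"obs1": [0.0], "obs2": [0.0]}}}
kind_plain, obs_plain = extract_actual_obs(plain)
kind_batch, obs_batch = extract_actual_obs(BatchLike(plain))
```

Here `kind_plain` ends up as "dict" and `kind_batch` as "batch", while both calls return the same nested observation.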