Preprocessor error on batches of observations

Another update: Based on other posts/issues (link1, link2, link3, and link4), I think the problem I'm having has something to do with my environment's observation space being a dict, with one item being "observations" and the other being "action_mask".

Based on some of the comments at the above links, here’s what I’ve tried so far, none of which work:

  0. (baseline) The bare observation space is:
observation_space = Dict({
    "observations": Dict({
        "ob_a": Box(..., dtype=int),   # length = 24
        "ob_b": Box(..., dtype=float)  # length = 8
    }),
    "action_mask": Box(..., dtype=int)  # length = 10
})
  1. Flatten "observations" so that the observation space looks like:
observation_space = Dict({
    "observations": Box(..., dtype=float),  # length = 32
    "action_mask": Box(..., dtype=int)      # length = 10
})
  2. Made all entries in "observations" and "action_mask" floats (vice "observations" being floats and "action_mask" being ints):
observation_space = Dict({
    "observations": Box(..., dtype=float),
    "action_mask": Box(..., dtype=float)
})
  3. Using a custom model that is a modified version of Ray's action mask example (see the registration sketch after this list). I've tried using this model with both options #1 and #2 above.
"""Model with action masking."""
# %% Imports
from __future__ import annotations

# Standard Library Imports
from typing import Any

# Third Party Imports
import torch
import torch.nn as nn
from gym.spaces import Dict, Space
from numpy import stack
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.torch_utils import FLOAT_MIN
from torch import Tensor, tensor


# %% Class


class MyActionMaskModel(TorchModelV2, nn.Module):
    """Model that handles simple discrete action masking.

    Include all regular model configuration parameters (fcnet_hiddens,
    fcnet_activation, etc.) in model_config.

    Custom model parameters are set in model_config["custom_model_config"]. The
    only custom model parameter is "no_masking" (default False).
    """

    def __init__(
        self,
        obs_space: Space,
        action_space: Space,
        num_outputs: int,
        model_config: dict,
        name: str,
        **kwargs,
    ):
        """Initialize action masking model.

        Args:
            obs_space (`Space`): A gym space.
            action_space (`Space`): A gym space.
            num_outputs (`int`): Number of outputs of neural net. Should be the
                size of the flattened action space.
            model_config (`dict`): Model configuration. Required inputs are:
                {
                    "fcnet_hiddens" (`list[int]`): Fully connected hidden layers.
                }
            name (`str`): Name of model.

        To disable action masking, set:
            model_config["custom_model_config"]["no_masking] = True.
        """
        # Check that the observation space is a dict that contains "action_mask"
        # and "observations" as keys.
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict)
        assert "action_mask" in orig_space.spaces
        assert "observations" in orig_space.spaces

        # Boilerplate Torch stuff.
        TorchModelV2.__init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            **kwargs,
        )
        nn.Module.__init__(self)

        # Build feed-forward layers
        self.internal_model = TorchFC(
            orig_space["observations"],
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

        last_layer_size = model_config["fcnet_hiddens"][-1]
        self.action_head = nn.Linear(last_layer_size, num_outputs)
        self.value_head = nn.Linear(last_layer_size, 1)

        # disable action masking --> will likely lead to invalid actions
        custom_config = model_config.get("custom_model_config", {})
        self.no_masking = False
        if "no_masking" in custom_config:
            self.no_masking = custom_config["no_masking"]

    def forward(
        self,
        input_dict: dict[str, Any],
        state: Any,
        seq_lens: Any,
    ) -> tuple[Tensor, Any]:
        """Forward propagate observations through the model.

        Takes a `dict` as an argument with the only key being "obs", which is either
        a sample from the observation space or a list of samples from the observation
        space.

        Can input either a single observation or multiple observations. If using
        a single observation, the input is a dict[dict[dict]]. If using
        multiple observations, the input is a dict[dict[list_of_dicts]].

        Args:
            input_dict (`dict`[`dict`]):
                {
                    "obs": {
                        "action_mask": `Tensor`,
                        "observations": `Tensor`,
                    }
                }
                or
                {
                    "obs": list[
                        {
                        "action_mask": `Tensor`,
                        "observations": `Tensor`
                        },
                        ...]
                }
            state (`Any`): List of RNN state tensors (unused by this model).
            seq_lens (`Any`): 1-D tensor of RNN sequence lengths (unused by this
                model).

        Returns:
            logits (`Tensor`): Logits in shape of (batch_size, num_outputs).
            state (`Any`): RNN state, passed through unchanged.
        """
        # Extract the action mask and observations from the input dict and convert
        # them to tensors, if necessary. If multiple observations are passed in,
        # stack the action masks and observations into larger tensors; the action
        # mask and observation tensors therefore have different shapes depending on
        # whether a single observation or multiple observations are passed in.
        # Convert tensors to float, if not already, because torch Linear layers
        # require float inputs
        # (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html#torch.nn.Linear).
        if isinstance(input_dict["obs"], list):
            # For multiple observations
            # action_mask is a [num_observations, len_mask] tensor
            # observation is a [num_observations, len_obs] tensor
            array_of_masks = stack(
                [a["action_mask"] for a in input_dict["obs"]], axis=0
            )
            action_mask = tensor(array_of_masks)
            array_of_obs = stack(
                [a["observations"] for a in input_dict["obs"]], axis=0
            )
            observation = tensor(array_of_obs).float()
        else:
            action_mask = input_dict["obs"]["action_mask"]
            observation = input_dict["obs"]["observations"].float()

        # Compute the unmasked logits.
        self.internal_model._features = (
            self.internal_model._hidden_layers.forward(observation)
        )
        # print(
        #     f"internal_model._features.size() = {self.internal_model._features.size()}"
        # )
        logits = self.action_head(self.internal_model._features)

        # If action masking is disabled, skip masking and return unmasked logits.
        # Otherwise, step into masking block.
        if not self.no_masking:
            # Convert action_mask into a [0.0 || -inf]-type mask.
            inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
            # print(f"logits.size() = {logits.size()}")
            # print(f"inf_mask.size() = {inf_mask.size()}")
            masked_logits = logits + inf_mask
            logits = masked_logits

        return logits, state

    def value_function(self) -> Tensor:
        """Get current value of value function.

        Returns:
            `Tensor[torch.float32]`: Value function value.
        """
        # get features and squeeze extra dimensions out.
        y = self.value_head(self.internal_model._features)
        y = y.squeeze(-1)
        return y
  4. Fully flattening the observation space, just to make sure the environment and constructor are not broken. This works, but I lose the ability to do action masking, which is critical to the environment, so this test was more of a sanity check than a reasonable path forward. Note that this test did not include the action masking model above, because that model relies on interfacing with a dict observation space.
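For reference, here is roughly how the custom model gets wired into the training config (a minimal sketch; the registered name "my_action_mask_model" is arbitrary, and the fcnet_hiddens values are illustrative, not the ones from my actual config):

from ray.rllib.models import ModelCatalog

# Register the custom model under an arbitrary string name.
ModelCatalog.register_custom_model("my_action_mask_model", MyActionMaskModel)

config = {
    "framework": "torch",
    "model": {
        "custom_model": "my_action_mask_model",
        # Regular model params are consumed by the internal TorchFC net.
        "fcnet_hiddens": [64, 64],
        # Set "no_masking": True to disable masking.
        "custom_model_config": {"no_masking": False},
    },
}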

The fact that I get the same errors regardless of whether I use the action masking model makes me think the error is solely in the interface of the environment.
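One way to test that theory independently of RLlib is to check the raw gym interface directly (a minimal sketch; MyEnv is a placeholder for my actual environment class):

from gym.spaces import flatten, flatten_space

env = MyEnv()  # placeholder for the actual environment class

# reset() and step() outputs should be contained in the observation space.
obs = env.reset()
assert env.observation_space.contains(obs), "reset() obs not in observation_space"

obs, reward, done, info = env.step(env.action_space.sample())
assert env.observation_space.contains(obs), "step() obs not in observation_space"

# The flattened sample should have the same length as the flattened space.
flat_space = flatten_space(env.observation_space)
flat_obs = flatten(env.observation_space, obs)
print(flat_space.shape, flat_obs.shape)  # these should agree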

Notably, the error is different for the bare environment vice the others. For case #0, the error is:

ValueError: could not broadcast input array from shape (8,) into shape (4,)

But for the other cases, the error is:

ValueError: could not broadcast input array from shape (32,) into shape (10,)
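Given the shapes in those broadcast errors, it may also help to look at what RLlib's preprocessor actually produces for the observation space (a minimal sketch, assuming observation_space is one of the Dict spaces above):

from ray.rllib.models.preprocessors import get_preprocessor

# get_preprocessor() returns the preprocessor class RLlib would pick for
# this space; instantiate it on the same space.
prep = get_preprocessor(observation_space)(observation_space)
print(prep.shape)  # the flattened shape RLlib expects

# Transform a raw sample; a mismatch between these shapes would reproduce
# the broadcast error.
sample = observation_space.sample()
print(prep.transform(sample).shape)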