Is any multi discrete action example for PPO or other algorithms?

I think an example of a custom model using a MultiDiscrete action space would be useful. It’s not apparent from the current examples how to translate typical Box logits into MultiDiscrete. For example, here is my own in-progress action masking class (based on the rllib example).

from typing import Any
import torch
import torch.nn as nn
from gym.spaces import Dict, Space
from gym.spaces.utils import flatten_space
from numpy import argmax, int64, ndarray, stack, zeros
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.torch_utils import FLOAT_MIN
from torch import Tensor, tensor

class MyActionMaskModel(TorchModelV2, nn.Module):
    """Model that handles simple discrete action masking.

    Include all regular model configuration parameters (fcnet_hiddens,
    fcnet_activation, etc.) in model_config.

    def __init__(
        obs_space: Space,
        action_space: Space,
        num_outputs: int,
        model_config: dict,
        name: str,
        """Initialize action masking model.

            obs_space (`Space`): A gym space.
            action_space (`Space`): A gym space.
            num_outputs (`int`): Number of outputs of neural net.
            model_config (`dict`): Model configuration. Required inputs are:
                    "fcnet_hiddens" (`list[int]`): Fully connected hidden layers.
            name (`str`): Name of model.
        # Check that the observation space is a dict that contains "action_mask"
        # and "observations" as keys.
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert isinstance(orig_space, Dict)
        assert "action_mask" in orig_space.spaces
        assert "observations" in orig_space.spaces

        # Get number of actions each agent can make (assumes all agents can make)
        # the same number of actions). Then calculate number of outputs from hidden
        # layers of NN. The number of outputs from the NN is the size of the flattened
        # MultiDiscrete action space.
        # Examples:
        #   - If the MultiDiscrete space is [2, 2], then num_inputs = 4.
        #   - If the MultiDiscrete space is [3, 3, 3], then num_inputs = 9.
        self.num_actions = action_space.nvec[0]
        action_space_flat = flatten_space(action_space)
        # This overwrites the input value num_outputs
        num_outputs = action_space_flat.shape[0]

        # Boilerplate Torch stuff.

        # Build feed-forward layers
        self.internal_model = TorchFC(
            name + "_internal",

        last_layer_size = model_config["fcnet_hiddens"][-1]
        self.action_head = nn.Linear(last_layer_size, num_outputs)
        self.value_head = nn.Linear(last_layer_size, 1)

        self.no_masking = False

    def forward(
        input_dict: dict[dict],
        state: Any,
        seq_lens: Any,
    ) -> [Tensor, Any]:
        """Forward propagate observations through the model.

        Takes a `dict` as an argument with the only key being "obs", which is either
        a sample from the observation space or a list of samples from the observation

        Can input either a single observation or multiple observations. If using
        a single observation, the input is a dict[dict[dict]]]. If using
        multiple observations, the input is a dict[dict[list_of_dicts]].

            input_dict (`dict`[`dict`]):
                    "obs": {
                        "action_mask": `ndarray` | `Tensor`,
                        "observations": `ndarray` | `Tensor`,
                    "obs": list[
                        "action_mask": `ndarray`,
                        "observations": `ndarray`
            state (`Any`): _description_
            seq_lens (`Any`): _description_

            out (`Tensor`): Log(SoftMax()) of action probabilities.
            state (`Any`): _description_
        # Extract the action mask and observations from the input dict and convert
        # to tensor, if necessary. Stack action masks and observations into larger
        # tensor if multiple obs are passed in. The action mask and observation
        # are different sizes depending on if multiple or single observations are
        # passed in.
        if type(input_dict["obs"]) is list:
            # For multiple observations
            # action_mask is a [num_observations, len_mask] tensor
            # observation is a [num_observations, len_obs] tensor
            array_of_masks = stack(
                [a["action_mask"] for a in input_dict["obs"]], axis=0
            action_mask = tensor(array_of_masks)
            array_of_obs = stack(
                [a["observations"] for a in input_dict["obs"]], axis=0
            observation = tensor(array_of_obs).float()
        elif type(input_dict["obs"]["action_mask"]) is ndarray:
            # For single observations in numpy dtype
            # action_mask is a [len_mask] tensor
            # observation is a [len_obs] tensor
            action_mask = tensor(input_dict["obs"]["action_mask"])
            observation = tensor(input_dict["obs"]["observations"]).float()
        elif type(input_dict["obs"]["action_mask"]) is Tensor:
            # For single observations in Tensor dtype
            action_mask = input_dict["obs"]["action_mask"]
            observation = input_dict["obs"]["observations"]

        # Compute the unmasked logits.
        self.internal_model._features = (
        # print(
        #     f"internal_model._features.size() = {self.internal_model._features.size()}"
        # )
        actions = self.action_head(self.internal_model._features)

        # If action masking is disabled, skip masking and return unmasked actions.
        # Otherwise, step into masking block.
        if self.no_masking is False:
            # Convert action_mask into a [0.0 || -inf]-type mask.
            inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
            # print(f"actions.size() = {actions.size()}")
            # print(f"inf_mask.size() = {inf_mask.size()}")
            masked_actions = actions + inf_mask
            actions = masked_actions

        # # Normalize outputs
        out = nn.functional.log_softmax(actions, dim=-1)

        return out, state

    def value_function(self) -> Tensor:
        """Get current value of value function.

            `Tensor[torch.float32]`: Value function value.
        # get features and squeeze extra dimensions out.
        y = self.value_head(self.internal_model._features)
        y = y.squeeze(-1)
        return y

And here is a short test with a random environment with a MultiDiscreteaction space.

from ray.rllib.examples.env.random_env import RandomEnv
from gym.spaces import Box, Dict, MultiDiscrete
from numpy import int64
from torch import tensor
import MyActionModel

rand_env = RandomEnv(
        "observation_space": Dict(
                "observations": Box(-1.0, 1.0, shape=(2,)),
                "action_mask": Box(
                    shape=(3 * 3,),
        "action_space": MultiDiscrete([3, 3, 3]),

rand_model = MyActionMaskModel(
    model_config={"fcnet_hiddens": [10, 5]},
    num_outputs=3 * 3,

obs_sample = rand_env.observation_space.sample()
obs_sample["observations"] = tensor(obs_sample["observations"])
obs_sample["action_mask"] = tensor(obs_sample["action_mask"])
obs_sample = {"obs": obs_sample}
print(f"action_mask = {obs_sample['obs']['action_mask']}")
# prints action_mask = tensor([1, 1, 0, 0, 1, 1, 1, 1, 0])

[md_actions, _] = rand_model.forward(obs_sample, None, None)
print(f"rand model actions = {md_actions}")
# prints rand model actions = tensor([-2.3527e+00, -1.7116e+00, -3.4000e+38, #-3.4000e+38, -1.8138e+00,
#      -2.3707e+00, -1.4054e+00, -1.5024e+00, -3.4000e+38],
#       grad_fn=<LogSoftmaxBackward0>)
    f"Calculated actions in action_space? {rand_env.action_space.contains(md_actions)}"
# prints Calculated actions in action_space? False

Even though the action_space in both the environment and model is set to MultiDiscrete, the output of model.forward() is not in action_space (as can be seen by the last line in the test). An example to walk through how to handle MultiDiscrete action spaces would be extremely helpful.

