Compute_single_action randomly errors without changing input

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.
    Ray 2.5.0

I’m using a custom LSTM/action masking model based off of this Ray example:

My custom model

"""Custom LSTM + Action Mask model."""
# %% Imports
# Standard Library Imports
from typing import Dict, List, Tuple

# Third Party Imports
import gymnasium as gym
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.misc import SlimFC
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork as TorchRNN
from ray.rllib.policy.rnn_sequencing import add_time_dimension
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.utils.typing import TensorType

torch, nn = try_import_torch()

# %% Class
class MaskedLSTM(TorchRNN, nn.Module):
    """Fully-connected layers feed into an LSTM layer, with action masking.

    The observation space must be a dict space with exactly two entries:
    "observations" (fed through the network) and "action_mask" (1-D, one
    entry per model output). The mask is applied to the raw logits in
    forward().
    """

    def __init__(
        self,
        obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        num_outputs: int,
        model_config: dict = None,
        name: str = None,
        **custom_model_kwargs,
    ):
        """Initialize MaskedLSTM model.

        Args:
            obs_space (gym.spaces.Space): Environment observation space.
            action_space (gym.spaces.Space): Environment action space.
            num_outputs (int): Number of outputs of model. Should be equal to size
                of flattened action_space.
            model_config (dict, optional): Used for Ray defaults. Defaults to {}.
            name (str, optional): Used for inheritance. Defaults to "MaskedLSTM".
            custom_model_kwargs: Configure size of FC net and LSTM layer. Required.

        Expected items in custom_model_kwargs:
            fcnet_hiddens (list[int]): Number and size of FC layers.
            fcnet_activation (str): Activation function for FC layers. See Ray
                SlimFC documentation for recognized args.
            lstm_state_size (int): Size of LSTM layer.
        """
        # NOTE(review): the pasted signature was missing `self` and the
        # `custom_model_kwargs` parameter that the body reads; it is restored
        # here as **kwargs -- confirm against the original file.

        # Convert space to proper gym space if handed is as a different type
        orig_space = getattr(obs_space, "original_space", obs_space)
        # Size of observations must include only "observations", not "action_mask".
        # Action mask must be 1d and same len as num_outputs.
        # custom_model_kwargs must include "lstm_state_size", "fcnet_hiddens",
        # and "fcnet_activation".
        assert "observations" in orig_space.spaces
        assert "action_mask" in orig_space.spaces
        assert len(orig_space.spaces) == 2
        assert len(orig_space["action_mask"].shape) == 1
        assert orig_space["action_mask"].shape[0] == num_outputs
        assert "lstm_state_size" in custom_model_kwargs
        assert "fcnet_hiddens" in custom_model_kwargs
        assert "fcnet_activation" in custom_model_kwargs

        lstm_state_size = custom_model_kwargs.get("lstm_state_size")

        # Defaults
        if model_config is None:
            model_config = {}
        if name is None:
            name = "MaskedLSTM"

        # Inheritance
        # NOTE(review): the initialiser calls were truncated in the paste; only
        # the argument list survived. Reconstructed as the usual
        # nn.Module + RLlib model initialisation -- confirm.
        nn.Module.__init__(self)
        super().__init__(
            obs_space, action_space, num_outputs, model_config, name
        )

        self.obs_size = orig_space["observations"].shape[0]
        # transition layer size: size of output of final hidden layer
        self.trans_layer_size = custom_model_kwargs["fcnet_hiddens"][-1]
        self.lstm_state_size = lstm_state_size

        # NOTE(review): arguments of the next two calls were lost in the
        # paste. They are rebuilt from the attributes computed above;
        # `batch_first=True` assumes the default batch-major layout used by
        # forward() (i.e. "_time_major" is False) -- confirm.
        self.fc_layers = self.makeFCLayers(
            model_config=custom_model_kwargs, input_size=self.obs_size
        )

        self.lstm = nn.LSTM(
            self.trans_layer_size, self.lstm_state_size, batch_first=True
        )
        self.action_branch = nn.Linear(self.lstm_state_size, num_outputs)
        self.value_branch = nn.Linear(self.lstm_state_size, 1)
        # Holds the current "base" output (before logits layer).
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        """Return the initial LSTM state as two zero vectors.

        Returns:
            list: Two zero tensors, each of length ``lstm_state_size``.
        """
        # NOTE(review): the original list entries were garbled in the paste
        # (only ", self.lstm_state_size)" survived). new_zeros() on an
        # existing weight keeps the state on the same device/dtype as the
        # model parameters -- confirm this matches the original intent.
        h = [
            self.action_branch.weight.new_zeros(self.lstm_state_size),
            self.action_branch.weight.new_zeros(self.lstm_state_size),
        ]
        return h

    @override(ModelV2)
    def value_function(self):  # noqa
        """Return the value estimate for the most recent forward pass, flattened."""
        assert self._features is not None, "must call forward() first"
        return torch.reshape(self.value_branch(self._features), [-1])

    # Override forward() to add an action mask step
    @override(ModelV2)
    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        """Adds time dimension to batch before sending inputs to forward_rnn().

        The raw logits returned by forward_rnn() are masked with the
        "action_mask" entry of the observation before being returned.

        Args:
            input_dict: Must contain "obs" with "observations" and
                "action_mask" entries.
            state: Recurrent state tensors passed on to forward_rnn().
            seq_lens: Sequence lengths passed on to add_time_dimension() and
                forward_rnn().

        Returns:
            Masked logits reshaped to [-1, num_outputs], and the new state.
        """
        # When training, input_dict is handed in with an extra nested level from
        # the environment (input_dict["obs"]).
        # Get observations from obs; not observations+action_mask
        flat_inputs = input_dict["obs"]["observations"].float()
        action_mask = input_dict["obs"]["action_mask"]

        # Note that max_seq_len != input_dict.max_seq_len != seq_lens.max()
        # as input_dict may have extra zero-padding beyond seq_lens.max().
        # Use add_time_dimension to handle this
        self.time_major = self.model_config.get("_time_major", False)
        # NOTE(review): the argument list of add_time_dimension() was lost in
        # the paste; reconstructed from the values prepared above -- confirm
        # against the installed RLlib version's signature.
        inputs = add_time_dimension(
            flat_inputs,
            seq_lens=seq_lens,
            framework="torch",
            time_major=self.time_major,
        )
        output, new_state = self.forward_rnn(inputs, state, seq_lens)
        output = torch.reshape(output, [-1, self.num_outputs])
        # Mask raw logits here! Then return masked values
        output = self.maskLogits(logits=output, mask=action_mask)
        return output, new_state

    def maskLogits(self, logits: TensorType, mask: TensorType):
        """Apply mask over raw logits.

        Entries where ``mask`` is 0 (or negative) receive a large negative
        offset (FLOAT_MIN); entries where ``mask`` is 1 are left unchanged.

        Args:
            logits: Raw logits.
            mask: Action mask, broadcastable against ``logits``.

        Returns:
            ``logits`` plus the log of the mask, with the log clamped from
            below at FLOAT_MIN.
        """
        # torch.log() of a negative entry is NaN and torch.clamp() propagates
        # NaN, which would poison the logits. Clamp the mask at 0 first so
        # such entries are simply treated as "masked out". Values >= 0 behave
        # exactly as before.
        safe_mask = torch.clamp(mask.to(logits.dtype), min=0)
        # log(0) = -inf is clamped to FLOAT_MIN so the result stays finite.
        inf_mask = torch.clamp(torch.log(safe_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask
        return masked_logits

    @override(TorchRNN)
    def forward_rnn(self, inputs, state, seq_lens):
        """Feeds `inputs` (B x T x ..) through the FC layers and the LSTM.

        Returns the resulting outputs as a sequence (B x T x ...).
        The LSTM output is stored in self._features for value_function().

        Returns:
            NN Outputs (B x T x ...) as sequence.
            The state batches as a List of two items (h- and c-states).
        """
        x = nn.functional.relu(self.fc_layers(inputs))
        # nn.LSTM expects states with a leading num_layers dimension, so add
        # it on the way in and strip it again on the way out.
        self._features, [h, c] = self.lstm(
            x, [torch.unsqueeze(state[0], 0), torch.unsqueeze(state[1], 0)]
        )
        action_out = self.action_branch(self._features)
        return action_out, [torch.squeeze(h, 0), torch.squeeze(c, 0)]

    def makeFCLayers(
        self, model_config: dict, input_size: int
    ) -> nn.Sequential:
        """Make fully-connected layers.

        See Ray SlimFC for details.

        Args:
            model_config (dict): {
                "fcnet_hiddens": (list[int]) Number of hidden layers is number of
                    entries; size of hidden layers is values of entries,
                "fcnet_activation": (str) Recognized activation function
            }
            input_size (int): Input layer size.

        Returns:
            nn.Sequential: Has N layers, where N = len(model_config["fcnet_hiddens"]).
        """
        hiddens = list(model_config.get("fcnet_hiddens", []))
        activation = model_config.get("fcnet_activation")

        self.fc_hiddens = hiddens
        self.fc_activation = activation

        layers = []
        prev_layer_size = input_size

        # Create hidden layers.
        for size in hiddens:
            # NOTE(review): the layer construction was missing from the paste
            # (the loop only updated prev_layer_size, leaving `layers` empty).
            # Reconstructed with SlimFC, which the file imports -- confirm.
            layers.append(
                SlimFC(
                    in_size=prev_layer_size,
                    out_size=size,
                    activation_fn=activation,
                )
            )
            prev_layer_size = size

        fc_layers = nn.Sequential(*layers)

        return fc_layers

I am having issues when I call policy.compute_single_action() on a policy that uses my custom model, using the script below. The script is a close copy of the example above. I get the error below (a ValueError reporting NaN logits) seemingly at random, but more often than not. Sometimes the script works correctly, but I cannot pin down the conditions under which it does. I can run it successfully, then run it again immediately afterwards with no changes, and it will error out.

Relevant portion of script:

ModelCatalog.register_custom_model("MaskedLSTM", MaskedLSTM)

lstm_state_size = 10
config = (
    .environment(RandomEnv, env_config=env_config)
            # Specify our custom model from above.
            "custom_model": "MaskedLSTM",
            # Extra kwargs to be passed to your model's c'tor.
            "custom_model_config": {
                "fcnet_hiddens": [6, 6],
                "fcnet_activation": "relu",
                "lstm_state_size": lstm_state_size,
algo =
policy = algo.get_policy()
print(f"policy = {policy}")

# 2 ways of getting initial state, both should give the same list of zeros.
state = [zeros([lstm_state_size], float32) for _ in range(2)]
print(f"state = {state}")

action, state_out, _ = policy.compute_single_action(


Traceback (most recent call last):
  File "/tests/nets/", line 137, in <module>
    action, state_out, _ = policy.compute_single_action(
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/", line 545, in compute_single_action
    out = self.compute_actions_from_input_dict(
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/", line 526, in compute_actions_from_input_dict
    return self._compute_action_helper(
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/utils/", line 24, in wrapper
    return func(self, *a, **k)
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/policy/", line 1171, in _compute_action_helper
    action_dist = dist_class(dist_inputs, self.model)
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/torch/", line 114, in __init__
    self.cats = [
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/torch/", line 115, in <listcomp>
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/torch/distributions/", line 66, in __init__
    super(Categorical, self).__init__(batch_shape, validate_args=validate_args)
  File "/anaconda3/envs/punch/lib/python3.10/site-packages/torch/distributions/", line 56, in __init__
    raise ValueError(
ValueError: Expected parameter logits (Tensor of shape (1, 2)) of distribution Categorical(logits: torch.Size([1, 2])) to satisfy the constraint IndependentConstraint(Real(), 1), but found invalid values:
tensor([[nan, nan]])

Given that the script does work sometimes, I think my model is probably implemented correctly, and I’m just writing my test poorly. But I could be wrong. Any help would be much appreciated.