Action masking & Dict observation space & 'avail_actions'?

Dear folks,
I am currently building my own environment to run somewhat user-friendly RL control experiments for a university project.
A dictionary observation space is very useful for me, as it makes managing the states/observations easy. However, I also want to use action masking, and with the standard example this seems to be rather difficult. It is really important that certain actions are avoided entirely for certain observations, not just made less probable.
With the model example below, the FullyConnectedNetwork/TorchFC in the init function is constructed from the shape of orig_space['observations']. With a Dict observation space this is a Dict object, and somewhere inside FullyConnectedNetwork/TorchFC its shape is used to initialize the neural network, which errors because Dict spaces don't have a shape. The obs_space object can't be used instead, because there the action_mask is flattened together with the observations.
Secondly, a dict object is passed on every forward pass, which is not a format the NN can handle.

For now I solved this by constructing a Box in the init with the flattened length of orig_space['observations'] and passing that as the obs_space argument to FullyConnectedNetwork/TorchFC. In the forward pass I turn the given dictionary into a 2D array with the batch on the first dimension. This method feels very hacky to me, and I am unsure whether I am over-engineering this.
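For reference, the model below assumes an observation space of roughly the following shape. This is only an illustrative sketch; the 'position' and 'velocity' sub-spaces are made-up placeholders, not my actual environment:

import numpy as np
from gymnasium.spaces import Box, Dict, Discrete

# Hypothetical example: 'action_mask' has one entry per discrete action,
# and 'observations' is the nested Dict that the model flattens internally.
observation_space = Dict({
    'action_mask': Box(0.0, 1.0, shape=(21,)),
    'observations': Dict({
        'position': Box(-np.inf, np.inf, shape=(3,)),  # placeholder sub-space
        'velocity': Box(-np.inf, np.inf, shape=(3,)),  # placeholder sub-space
    }),
})
action_space = Discrete(21)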

Torch code below (I use PPO with discrete actions; the action space is for now a simple Discrete(21) object):

import numpy as np
from gymnasium.spaces import Box, Dict

from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

torch, nn = try_import_torch()

  
def length_of_observation_space(dict_space):
    """Compute the length of the observation space if it were flattened out.

    Input:
        dict_space: Dict gym space

    Output:
        flattened_length: length of the flattened space
    """
    flattened_length = 0
    for key in dict_space:
        flattened_length += np.prod(dict_space[key].shape)
    return flattened_length

class TorchActionMaskModel(TorchModelV2, nn.Module):
    """
====================================================================================================

    PyTorch version of above ActionMaskingModel."""

    def __init__(
        self,
        obs_space,
        action_space,
        num_outputs,
        model_config,
        name,
        **kwargs,
    ):
        orig_space = getattr(obs_space, 'original_space', obs_space)
        assert (
            isinstance(orig_space, Dict)
            and 'action_mask' in orig_space.spaces
            and 'observations' in orig_space.spaces
        )
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name, **kwargs
        )

        self.flattened_length = length_of_observation_space(orig_space['observations'])

        nn.Module.__init__(self)
        self.internal_model = TorchFC(
            Box(-np.inf, np.inf, shape=(self.flattened_length,)),
            action_space,
            num_outputs,
            model_config,
            name + '_internal',
        )

        # disable action masking --> will likely lead to invalid actions
        self.no_masking = False
        if 'no_masking' in model_config['custom_model_config']:
            self.no_masking = model_config['custom_model_config']['no_masking']

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.

        action_mask = input_dict['obs']['action_mask']

        # Flatten the observation dict into a 2D tensor [batch, features].
        # Each value in the dict is a tensor with the batch on its first dimension,
        # so flattening per key and concatenating along dim=1 keeps the batches intact.
        flattened_inputs = torch.cat([t.view(t.shape[0], -1) for t in input_dict['obs']['observations'].values()], dim=1)

        # Compute the unmasked logits.
        logits, _ = self.internal_model({'obs': flattened_inputs})
        
        # If action masking is disabled, directly return unmasked logits
        if self.no_masking:
            return logits, state

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()
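
For completeness, here is a minimal sketch of how I wire the model into PPO. The model name 'action_mask_model' and the env id are placeholders, and I am assuming the Ray 2.x PPOConfig API:

from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog

# Register the custom model under an arbitrary name (placeholder).
ModelCatalog.register_custom_model('action_mask_model', TorchActionMaskModel)

config = (
    PPOConfig()
    .environment('MyMaskedEnv-v0')  # placeholder env id
    .framework('torch')
    .training(model={
        'custom_model': 'action_mask_model',
        'custom_model_config': {'no_masking': False},
    })
)
algo = config.build()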

Furthermore, some examples use parametric actions. Because of the 'avail_actions' key I always assumed this is meant for slightly more complicated action spaces, for example Tuple(Discrete(n), Discrete(n)), but I keep being unsure about this since the documentation is not very clear on it. It would be amazing if someone could elaborate on that.

To sum up the two questions:

  • Did I properly use action masking?
  • What exactly is the 'avail_actions' key used for?

Thanks in advance,
Chris

Have you also looked at the full example?

You should write an environment that fits this model. You can simply wrap whatever environment you have to produce these dictionary observations.
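
A minimal sketch of such a wrapper, assuming the inner environment already emits a Dict observation and exposes a (hypothetical) compute_action_mask() helper:

import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict


class ActionMaskWrapper(gym.ObservationWrapper):
    """Wrap an env so its observations fit the action-masking model above."""

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Dict({
            'action_mask': Box(0.0, 1.0, shape=(env.action_space.n,)),
            'observations': env.observation_space,  # the original Dict space
        })

    def observation(self, obs):
        # compute_action_mask() is a hypothetical hook on the wrapped env that
        # returns 1.0 for allowed actions and 0.0 for forbidden ones.
        mask = self.env.unwrapped.compute_action_mask().astype(np.float32)
        return {'action_mask': mask, 'observations': obs}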

I am not aware of an alternative for masking other than reducing the logits of unavailable actions to -inf. Do you have a concrete proposal? In practice this works: the masked logits are pushed down by roughly 3.4e38, so the resulting probability underflows below the smallest positive float32 (around 1e-38) to effectively zero. We are unlikely to see RLlib choose one of these actions in our lifetime.
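
A quick numerical sketch of that claim (my own illustration, using the same FLOAT_MIN constant as the model above):

import torch
from ray.rllib.utils.torch_utils import FLOAT_MIN

logits = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 0.0, 1.0]])  # action 1 is unavailable

# log(0) = -inf, clamped to FLOAT_MIN (about -3.4e38).
inf_mask = torch.clamp(torch.log(mask), min=FLOAT_MIN)
probs = torch.softmax(logits + inf_mask, dim=-1)
# probs[0, 1] underflows to exactly 0.0 in float32, so the masked
# action can never be sampled.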
