Masking in custom autoregressive ActionDistribution

Hi,
I want to add action masking to my custom autoregressive ActionDistribution, but I don’t see any way to get the mask there. Normally I provide the mask from the env as an input to the custom model. But now that the sampling happens inside the ActionDistribution, I don’t know how to get the mask and apply it to the action_head. Is there any example or idea of how to do this?
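For context, this is roughly what my usual (non-autoregressive) masking setup looks like: the env puts the mask into the obs dict and the model folds it into the logits. The "observations" / "action_mask" keys and the MaskedActionModel name below are simplified placeholders, not my real code.

# Simplified sketch of the usual masking pattern: the obs space is a Dict
# with "observations" and "action_mask" entries (placeholder names), and the
# model adds log(mask) to the logits before returning them.
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


class MaskedActionModel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        orig_space = getattr(obs_space, "original_space", obs_space)
        self.internal_model = FullyConnectedNetwork(
            orig_space["observations"], action_space, num_outputs,
            model_config, name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]})
        mask = tf.cast(input_dict["obs"]["action_mask"], tf.float32)
        # Push masked-out actions (mask == 0) to ~ -inf in logit space.
        inf_mask = tf.maximum(tf.math.log(mask), tf.float32.min)
        return logits + inf_mask, state

    def value_function(self):
        return self.internal_model.value_function()

My current autoregressive distribution (without any masking) looks like this: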

# Imports needed by the snippet below (old ModelV2 / tf stack paths):
from typing import List, Tuple, Union

import gym
import numpy as np

from ray.rllib.models.action_dist import ActionDistribution
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.tf.tf_action_dist import Categorical, MultiCategorical
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()

# VOCAB_NUM_CLASSES and MAX_SENTENCE_LEN are project constants defined elsewhere.


class CustomMultiCategorical(MultiCategorical):

    def __init__(
            self,
            inputs: List[TensorType],
            model: ModelV2,
            input_lens: Union[List[int], np.ndarray, Tuple[int, ...]] = [VOCAB_NUM_CLASSES] * MAX_SENTENCE_LEN,
            action_space=None, *args, **kwargs
    ):
        # Skip TFActionDistribution.__init__; sample_op and logp_op are built
        # below, after the autoregressive sampling loop.
        ActionDistribution.__init__(self, inputs, model)

        self.cats = []
        self.sampled_cats = []
        batch_size = tf.shape(inputs)[0]
        # Start token: one id past the vocabulary.
        last_cat_res = tf.fill((batch_size, 1), VOCAB_NUM_CLASSES)
        decoder_states = (None, None)
        # One Categorical per output token. The split chunks themselves are not
        # used; the loop only runs once per token, and each step conditions the
        # action head on the full context plus the previously sampled token.
        for i, _ in enumerate(tf.split(inputs, input_lens, axis=1)):
            last_cat, decoder_states = self._get_autoregressive_categorical(
                inputs, model, last_cat_res, decoder_states)
            self.cats.append(last_cat)
            last_cat_res = last_cat.sample()
            self.sampled_cats.append(last_cat_res)

        self.action_space = action_space
        if self.action_space is None:
            self.action_space = gym.spaces.MultiDiscrete(
                [c.inputs.shape[1] for c in self.cats]
            )
        self.sample_op = self._build_sample_op()
        self.sampled_action_logp_op = self.logp(self.sample_op)

    def _get_autoregressive_categorical(self, input_, model, last_cat_res, decoder_states):
        # The custom model exposes its decoder as nlp_action_head().
        logits, state_h, state_c = model.nlp_action_head(
            input_, last_cat_res, decoder_states[0], decoder_states[1])
        return Categorical(logits, model), (state_h, state_c)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
            action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        # Int Box.
        if isinstance(action_space, gym.spaces.Box):
            assert action_space.dtype.name.startswith("int")
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            assert np.all(action_space.low == low_)
            assert np.all(action_space.high == high_)
            return np.prod(action_space.shape, dtype=np.int32) * (high_ - low_ + 1)
        # MultiDiscrete space.
        else:
            # nvec is already integer, so no casting needed.
            return np.sum(action_space.nvec)
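For completeness, this is how the distribution gets plugged in via the model config (the registered names here are just examples):

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_action_dist("custom_multi_cat", CustomMultiCategorical)

config = {
    "model": {
        "custom_model": "my_nlp_model",            # placeholder name for the custom model
        "custom_action_dist": "custom_multi_cat",
    },
}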

It looks like I can store the mask on the model and then use it inside the action distribution, but I still need to test whether this works properly.
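Concretely, something like the following untested sketch is what I have in mind: the model stashes the mask on itself in forward(), and the distribution reads it back through the model reference it already receives. The last_action_mask attribute, the _encode() helper, the extra step argument, and the assumed mask shape (batch, MAX_SENTENCE_LEN, VOCAB_NUM_CLASSES) are all my own assumptions, not RLlib API.

# In the custom model: keep the current mask around for the distribution.
# (tf comes from try_import_tf() as in the snippets above.)
def forward(self, input_dict, state, seq_lens):
    # Mask from the env via the obs dict; assumed shape:
    # (batch, MAX_SENTENCE_LEN, VOCAB_NUM_CLASSES).
    self.last_action_mask = input_dict["obs"]["action_mask"]
    context, state = self._encode(input_dict, state, seq_lens)  # hypothetical encoder helper
    return context, state

# In CustomMultiCategorical: pass the loop index i as "step" and mask the
# logits of the current token before building the Categorical.
def _get_autoregressive_categorical(self, input_, model, last_cat_res, decoder_states, step):
    logits, state_h, state_c = model.nlp_action_head(
        input_, last_cat_res, decoder_states[0], decoder_states[1])
    mask = tf.cast(model.last_action_mask[:, step, :], tf.float32)
    logits = logits + tf.maximum(tf.math.log(mask), tf.float32.min)
    return Categorical(logits, model), (state_h, state_c)

If I understand the old ModelV2 API stack correctly, the distribution is constructed with the same model instance whose forward() pass produced its inputs, so reading an attribute off self.model should be safe, but I haven’t verified this yet.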