Action masking error

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi, I am trying to implement simple discrete action masking in RLlib. The idea is to have the entire action space (182 actions) available at the beginning of an episode; once an action has been chosen, it cannot be picked again, so the set of valid actions shrinks at every step. I have read through the related examples on GitHub, as well as some previous posts on the forum, so I do have a vague idea. However, I keep getting the same error message when initializing the algo.

Also, is it mandatory to set "framework": "tf2" when using the custom model? You can find snippets of the code below. Thank you very much!

import gym
import numpy as np
from gym.spaces import Box, Dict, Discrete


class ActionMaskEnv(gym.Env):
    def __init__(self, env_config):
        super().__init__()
        # Define observation space
        self.observation_space = Dict({
            "observations": Box(low=0, high=1, shape=(910,), dtype=np.int32),
            "action_mask": Box(0, 1, shape=(182,), dtype=np.float32),
        })
        # Define action space
        self.action_space = Discrete(182)
        # Action mask (initially all actions possible);
        # updated at every step until the entire array is zeros.
        self.mask = np.ones(182, dtype=np.float32)
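The snippet above omits reset()/step(), which is where the "ever decreasing action space" logic would live. A minimal standalone sketch of that mechanic (the observation contents and zero reward here are placeholders, not from the original post):

```python
import numpy as np


class ShrinkingMaskSketch:
    """Illustrative only: tracks which of n_actions are still available."""

    def __init__(self, n_actions=182):
        self.n_actions = n_actions
        self.mask = np.ones(n_actions, dtype=np.float32)

    def reset(self):
        # All actions available again at the start of an episode.
        self.mask = np.ones(self.n_actions, dtype=np.float32)
        return {"observations": np.zeros(910, dtype=np.int32),
                "action_mask": self.mask.copy()}

    def step(self, action):
        assert self.mask[action] == 1.0, "action was already taken"
        self.mask[action] = 0.0  # this action can never be picked again
        done = not self.mask.any()  # episode ends when no actions remain
        obs = {"observations": np.zeros(910, dtype=np.int32),
               "action_mask": self.mask.copy()}
        return obs, 0.0, done, {}
```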

import tensorflow as tf
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


class ActionMaskModel(TFModelV2):
    """Model that handles simple discrete action masking.

    This assumes the outputs are logits for a single Categorical action dist.
    Getting this to work with a more complex output (e.g., if the action space
    is a tuple of several distributions) is also possible but left as an
    exercise to the reader.
    """

    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name, **kwargs
    ):
        orig_space = getattr(obs_space, "original_space", obs_space)
        assert (
            isinstance(orig_space, Dict)
            and "action_mask" in orig_space.spaces
            and "observations" in orig_space.spaces
        )

        super().__init__(obs_space, action_space, num_outputs, model_config, name)

        self.internal_model = FullyConnectedNetwork(
            orig_space["observations"],
            action_space,
            num_outputs,
            model_config,
            name + "_internal",
        )

        # Disable action masking --> will likely lead to invalid actions.
        self.no_masking = model_config["custom_model_config"].get("no_masking", False)

    def forward(self, input_dict, state, seq_lens):
        # Extract the available actions tensor from the observation.
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the unmasked logits.
        logits, _ = self.internal_model({"obs": input_dict["obs"]["observations"]})

        # If action masking is disabled, directly return unmasked logits
        if self.no_masking:
            return logits, state

        # Convert action_mask into a [0.0 || -inf]-type mask.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        masked_logits = logits + inf_mask

        # Return masked logits.
        return masked_logits, state

    def value_function(self):
        return self.internal_model.value_function()
from ray.rllib.algorithms import dqn
from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model("action_mask_model", ActionMaskModel)

algo = dqn.DQN(env=ActionMaskEnv, config={
    "rollout_fragment_length": 100,
    "env_config": {},
    "hiddens": [],
    "model": {
        "custom_model": "action_mask_model",
    },
    "train_batch_size": 1000,
    "framework": "tf2",
    "horizon": 182,
    "eager_tracing": True,
    "min_train_timesteps_per_iteration": 100,
    "min_sample_timesteps_per_iteration": 2000,
})