Masking Invalid Actions for DQN Algorithm

I am not able to successfully mask all of the invalid actions when using the DQN algorithm.
Please find the custom model below:

import tensorflow as tf
from gym.spaces import Box
from ray.rllib.agents.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork


class ActionMaskModel(DistributionalQTFModel):

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name, true_obs_shape=(870,), **kw):
        super(ActionMaskModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)
        print("action_space :", action_space)
        print("model_config :", model_config)
        # Embedding network over the "true" observation only (the mask and
        # available-action tensors are handled separately in forward()).
        self.action_embed_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape), action_space, num_outputs,
            model_config, name + "_action_embed")
    def forward(self, input_dict, state, seq_lens):
        # Extract the available-actions tensor and the mask from the observation.
        avail_actions = input_dict["obs"]["avail_actions"]
        action_mask = input_dict["obs"]["action_mask"]

        # Compute the predicted action embedding from the true observation.
        action_embed, _ = self.action_embed_model(
            {"obs": input_dict["obs"]["state"]})

        # Expand the model output to [BATCH, 1, EMBED_SIZE]. Note that the
        # avail_actions tensor is of shape [BATCH, MAX_ACTIONS, EMBED_SIZE].
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product => shape of logits is [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # Mask out invalid actions (use tf.float32.min for stability).
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state
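As a sanity check, the log-mask arithmetic on its own behaves as intended: log(1) = 0 leaves valid logits unchanged, while log(0) = -inf is clamped to the most negative finite float32, so invalid actions can never win an argmax. A minimal NumPy sketch (the mask and logit values here are made up for illustration):

```python
import numpy as np

# Hypothetical 4-action example: 1 = valid, 0 = invalid.
action_mask = np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32)
action_logits = np.array([0.5, 9.0, 1.2, 3.3], dtype=np.float32)

# log(1) = 0 leaves valid logits untouched; log(0) = -inf is clamped
# to the most negative finite float32 so the sum stays finite.
with np.errstate(divide="ignore"):
    inf_mask = np.maximum(np.log(action_mask), np.finfo(np.float32).min)

masked_logits = action_logits + inf_mask
best = int(np.argmax(masked_logits))  # an invalid action can never be picked
```

Even though action 1 has the largest raw logit (9.0), it is masked out and the argmax lands on a valid action.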

Am I missing anything here? I am even inheriting from DistributionalQTFModel.

Hi @Archana_R,

Action masking will not work like that for DQN. The values you are masking in forward() are intermediate activations, not the policy logits for the actions. The type of masking you are doing only works for on-policy models like PPO or A2C.
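For a value-based method like DQN, the mask has to be applied to the Q-values themselves: both when selecting the greedy action and when taking the max over next-state Q-values for the TD target. A minimal NumPy sketch of the idea (the helper name and all values are made up; in RLlib this logic would have to live where the Q-values are produced, not in forward()):

```python
import numpy as np

def masked_argmax(q_values, action_mask):
    """Greedy action over valid actions only (hypothetical helper)."""
    neg_inf = np.finfo(np.float32).min
    # Replace Q-values of invalid actions with the most negative float32.
    masked_q = np.where(action_mask > 0, q_values, neg_inf)
    return int(np.argmax(masked_q)), masked_q

# Hypothetical Q-values and mask for a 4-action problem.
q = np.array([2.0, 5.0, 1.0, 3.0], dtype=np.float32)
mask = np.array([1.0, 0.0, 1.0, 1.0], dtype=np.float32)

action, masked_q = masked_argmax(q, mask)  # greedy action among valid ones
target_max = masked_q.max()                # masked max for the TD target
```

Without the mask, action 1 (Q = 5.0) would be chosen and would also inflate the TD target; with it, both the greedy choice and the bootstrap max come from valid actions only.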

How do you suggest I do it for DQN?