I am not able to successfully mask all the invalid actions when using the DQN algorithm.
Here is the custom model I created:
class ActionMaskModel(DistributionalQTFModel):
    """Parametric-action Q model that masks out invalid actions.

    The observation is expected to be a dict with keys:
      * ``"state"``: the real observation fed to the embedding network
        (assumed to match ``true_obs_shape`` -- confirm against the env);
      * ``"avail_actions"``: per-action embeddings, assumed shape
        [BATCH, MAX_ACTIONS, EMBED_SIZE];
      * ``"action_mask"``: assumed 1 for valid and 0 for invalid actions,
        shape [BATCH, MAX_ACTIONS].

    NOTE: with DQN the output of ``forward`` is not used directly as the
    Q-values. It is passed through the Q head built by
    ``DistributionalQTFModel`` from the kwargs DQN supplies (``q_hiddens``,
    ``dueling``, ``use_noisy``, ``num_atoms``). Any extra layers, the dueling
    value stream, noisy layers or distributional atoms transform the masked
    scores, so invalid actions can get selected again. Run DQN with
    ``"hiddens": []``, ``"dueling": False``, ``"noisy": False`` and
    ``"num_atoms": 1`` so the masked output is used as-is.
    """

    def __init__(self, obs_space, action_space, num_outputs,
                 model_config, name, true_obs_shape=(870,),
                 action_embed_size=144, **kw):
        """Build the Q model and the embedding network for the real obs.

        Args:
            obs_space: Observation space handed over by RLlib.
            action_space: Action space handed over by RLlib.
            num_outputs: Output size requested by RLlib.
            model_config: The ``model`` config dict.
            name: Model name; the embedding sub-model is suffixed
                ``"_action_embed"``.
            true_obs_shape: Shape of ``obs["state"]``.
            action_embed_size: Size of the predicted intent embedding; must
                equal EMBED_SIZE of ``obs["avail_actions"]``.
            **kw: Q-head options forwarded to ``DistributionalQTFModel``.
        """
        import warnings

        super(ActionMaskModel, self).__init__(
            obs_space, action_space, num_outputs, model_config, name, **kw)

        # The mask applied in forward() only survives if the Q head passes
        # the model output through untouched (see class docstring).
        if (kw.get("q_hiddens") or kw.get("dueling", False)
                or kw.get("use_noisy", False)
                or kw.get("num_atoms", 1) > 1):
            warnings.warn(
                "ActionMaskModel: the DQN Q head will post-process the "
                "masked model output, so invalid actions may still be "
                "chosen. Set 'hiddens': [], 'dueling': False, "
                "'noisy': False and 'num_atoms': 1 in the DQN config.",
                stacklevel=2,
            )

        self.action_embed_model = FullyConnectedNetwork(
            Box(-1, 1, shape=true_obs_shape),
            action_space,
            action_embed_size,
            model_config,
            name + "_action_embed",
        )
        # NOTE(review): older RLlib (TF ModelV2) examples also call
        # self.register_variables(self.action_embed_model.variables()) here;
        # confirm whether your RLlib version still requires it.

    def forward(self, input_dict, state, seq_lens):
        """Return per-action scores with invalid actions masked out.

        Args:
            input_dict: Dict whose ``"obs"`` entry holds the ``"state"``,
                ``"avail_actions"`` and ``"action_mask"`` tensors.
            state: Recurrent state list, returned unchanged.
            seq_lens: Unused; part of the ModelV2 ``forward`` signature.

        Returns:
            Tuple ``(scores, state)``, where ``scores`` is [BATCH, MAX_ACTIONS]
            and masked-out actions are clamped to ``tf.float32.min``.
        """
        obs = input_dict["obs"]
        avail_actions = obs["avail_actions"]
        # tf.math.log needs a float tensor; the mask may arrive as ints.
        action_mask = tf.cast(obs["action_mask"], tf.float32)

        # Predicted "intent" embedding: [BATCH, EMBED_SIZE].
        action_embed, _ = self.action_embed_model({"obs": obs["state"]})

        # [BATCH, EMBED_SIZE] -> [BATCH, 1, EMBED_SIZE] so it broadcasts
        # against avail_actions ([BATCH, MAX_ACTIONS, EMBED_SIZE]) row by row.
        intent_vector = tf.expand_dims(action_embed, 1)

        # Batch dot product over the embedding axis -> [BATCH, MAX_ACTIONS].
        action_logits = tf.reduce_sum(avail_actions * intent_vector, axis=2)

        # log(1) = 0 leaves valid actions untouched. log(0) = -inf is clamped
        # to tf.float32.min so later arithmetic stays finite.
        inf_mask = tf.maximum(tf.math.log(action_mask), tf.float32.min)
        return action_logits + inf_mask, state
Am I missing anything here? I am even inheriting from DistributionalQTFModel.