Hi @jmugan,
Sorry for the delayed response, I had a busy weekend. I think you could try something like this:
from typing import Sequence

import gym

from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.utils.typing import ModelConfigDict


class MyVeryOwnDQNTorchModel(DQNTorchModel):
    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str,
                 *,
                 q_hiddens: Sequence[int] = (256, ),
                 dueling: bool = False,
                 dueling_activation: str = "relu",
                 num_atoms: int = 1,
                 use_noisy: bool = False,
                 v_min: float = -10.0,
                 v_max: float = 10.0,
                 sigma0: float = 0.5,
                 add_layer_norm: bool = False):
        # Everything after the bare "*" in DQNTorchModel's signature is
        # keyword-only, so forward those as keyword arguments.
        super(MyVeryOwnDQNTorchModel, self).__init__(
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            q_hiddens=q_hiddens,
            dueling=dueling,
            dueling_activation=dueling_activation,
            num_atoms=num_atoms,
            use_noisy=use_noisy,
            v_min=v_min,
            v_max=v_max,
            sigma0=sigma0,
            add_layer_norm=add_layer_norm)
        # Mask of large negative values for invalid actions; set in forward().
        self.inf_mask = None
    def forward(self, input_dict, state, seq_lens):
        # self.inf_mask = ...  # your masking logic goes here
        return super(MyVeryOwnDQNTorchModel, self).forward(
            input_dict, state, seq_lens)
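    # NOTE: One common way to build the mask (this assumes a Dict obs space
    # with an "action_mask" field, as in RLlib's parametric-actions examples;
    # adapt it to your own observation layout) would be, inside forward():
    #
    #     action_mask = input_dict["obs"]["action_mask"]
    #     self.inf_mask = torch.clamp(
    #         torch.log(action_mask), min=torch.finfo(torch.float32).min)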
    def get_q_value_distributions(self, model_out):
        """Returns distributional values for Q(s, a) given a state embedding.

        Override this in your custom model to customize the Q output head.

        Args:
            model_out (Tensor): Embedding from the model layers.

        Returns:
            (action_scores, logits, dist) if num_atoms == 1, otherwise
            (action_scores, z, support_logits_per_action, logits, dist)
        """
        action_scores, *rest = super(
            MyVeryOwnDQNTorchModel, self).get_q_value_distributions(model_out)
        # Apply the masking here using self.inf_mask, e.g.:
        # action_scores = action_scores + self.inf_mask
        return (action_scores, *rest)
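In case it helps, here is a minimal sketch of how you could wire the custom model into a DQN trainer. The model name "my_dqn_model" and the env "CartPole-v0" are just placeholders; substitute your own masked-action env and config:

from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.models import ModelCatalog

# Register the custom model under a name RLlib can look up via the config.
ModelCatalog.register_custom_model("my_dqn_model", MyVeryOwnDQNTorchModel)

trainer = DQNTrainer(
    env="CartPole-v0",  # placeholder; use your own env with action masking
    config={
        "framework": "torch",
        "model": {
            "custom_model": "my_dqn_model",
        },
    })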