Override get_q_value_distributions

Hi, I’m trying to mask actions for DQN.

Line 134 of dqn_torch_model.py (ray-project/ray, master) says to override get_q_value_distributions, but when I do that it ignores my get_q_value_distributions and instead calls the one in DQNTorchModel. I guess this is because build_q_model_and_distribution, around line 169 of dqn_torch_policy.py, sets model_interface=DQNTorchModel.

So, how do I get it to use my get_q_value_distributions? Do I have to create a new model interface and pass that in instead? That’s pretty deep in the DQN abstraction, so I am probably confused about something. Thanks!
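For context, the call I'm referring to looks roughly like this (paraphrasing from memory, not quoting the source verbatim; obs_space, action_space, num_outputs, and config all come from the surrounding policy-build function):

from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models import ModelCatalog

# Rough paraphrase of the call inside build_q_model_and_distribution
# (dqn_torch_policy.py); exact argument names and values may differ.
model = ModelCatalog.get_model_v2(
    obs_space=obs_space,
    action_space=action_space,
    num_outputs=num_outputs,        # an embedding size here, not action_space.n, when "hiddens" is set
    model_config=config["model"],   # this is where "custom_model" gets read
    framework="torch",
    model_interface=DQNTorchModel,  # mixed into whatever model class gets built
    name="q_func",
    # ... plus DQN-specific kwargs (q_hiddens, dueling, num_atoms, ...)
)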

Hi @jmugan,

Two questions for you:

  1. I am sure you are doing it right, but just as a double check: how are you specifying the custom model in the config? (Something like the snippet after this list.)

  2. Is your custom model a subclass of DQNTorchModel?
    If it is not, the code you pointed at in build_q_model_and_distribution will make it one (by mixing DQNTorchModel in as the model interface).

The way that multiple inheritance is ordered, it would find RLlib's version of get_q_value_distributions before yours.
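Regarding (1), the config usually just needs to point at the name the model was registered under, e.g. (all names here are placeholders):

from ray.rllib.agents.dqn import DQNTrainer
from ray.rllib.models import ModelCatalog

# Register your model class under a name...
ModelCatalog.register_custom_model("my_masked_model", MyMaskedModel)

# ...and reference that name in the trainer config.
config = {
    "framework": "torch",
    "model": {
        "custom_model": "my_masked_model",
        # "custom_model_config": {...},  # optional kwargs forwarded to your model
    },
}
trainer = DQNTrainer(config=config, env="YourEnv-v0")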


The model is registered like this:

ModelCatalog.register_custom_model("unit_model", OurModel)

and OurModel inherits from FullyConnectedNetwork.

If I try to inherit from DQNTorchModel, I don't know how to write the custom forward method where I can put in self.inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX).
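(For anyone else reading: outside of DQN, that clamp line usually lives in a forward like the sketch below, as in RLlib's parametric-actions examples. The Dict obs keys "action_mask" and "observations" and the class name are assumptions for illustration.)

import torch
from torch import nn

from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2

FLOAT_MIN = torch.finfo(torch.float32).min  # stand-in for RLlib's FLOAT_MIN constant


class MaskedFCModel(TorchModelV2, nn.Module):
    """Sketch only: assumes the env obs space is a Dict with
    "action_mask" and "observations" keys (illustrative names)."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        orig_space = getattr(obs_space, "original_space", obs_space)
        # Internal FC net over the real observations only (not the mask).
        self.internal_model = FullyConnectedNetwork(
            orig_space["observations"], action_space, num_outputs,
            model_config, name + "_internal")

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]
        logits, _ = self.internal_model(
            {"obs": input_dict["obs"]["observations"]})
        # Push logits of invalid actions towards -inf.
        inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        return logits + inf_mask, state

    def value_function(self):
        return self.internal_model.value_function()

With DQN, though, the tensor coming out of forward is the state embedding that get_q_value_distributions consumes, so the mask has to be applied to the Q-values there instead, which is what the rest of this thread is about.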

Hi @jmugan,

Sorry for the delayed response. I had a busy weekend. I would think you might try something like this:

import gym
from typing import Sequence

from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.utils.typing import ModelConfigDict


class MyVeryOwnDQNTorchModel(DQNTorchModel):

    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str,
                 *,
                 q_hiddens: Sequence[int] = (256, ),
                 dueling: bool = False,
                 dueling_activation: str = "relu",
                 num_atoms: int = 1,
                 use_noisy: bool = False,
                 v_min: float = -10.0,
                 v_max: float = 10.0,
                 sigma0: float = 0.5,
                 add_layer_norm: bool = False):
        super().__init__(
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            q_hiddens=q_hiddens,
            dueling=dueling,
            dueling_activation=dueling_activation,
            num_atoms=num_atoms,
            use_noisy=use_noisy,
            v_min=v_min,
            v_max=v_max,
            sigma0=sigma0,
            add_layer_norm=add_layer_norm)
        self.inf_mask = None

    def forward(self, input_dict, state, seq_lens):
        # self.inf_mask = ...  # your masking logic goes here
        return super().forward(input_dict, state, seq_lens)

    def get_q_value_distributions(self, model_out):
        """Returns distributional values for Q(s, a) given a state embedding.

        Override this in your custom model to customize the Q output head.

        Args:
            model_out (Tensor): Embedding from the model layers.

        Returns:
            (action_scores, logits, dist) if num_atoms == 1, otherwise
            (action_scores, z, support_logits_per_action, logits, dist)
        """
        action_scores, logits, *rest = super().get_q_value_distributions(model_out)
        # Apply your masking here using self.inf_mask, e.g. for num_atoms == 1:
        # action_scores = action_scores + self.inf_mask
        return (action_scores, logits, *rest)
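Once that builds, one way to double check that your get_q_value_distributions is the one actually being used is to inspect the model RLlib constructed. Treat the attribute access below as a debugging hint rather than a stable API; it may differ across Ray versions, and the config/env names are placeholders:

from ray.rllib.agents.dqn import DQNTrainer

trainer = DQNTrainer(config=config, env="YourEnv-v0")
model = trainer.get_policy().model
print(type(model).__mro__)                    # where your class sits in the MRO
print(type(model).get_q_value_distributions)  # which implementation wins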

Thanks! I tried something like that, but when I called super on the forward it raised NotImplementedError, saying DQNTorchModel doesn't have a forward implemented. Do you know why that would be? It's probably something like the forward gets built automatically, but going the custom model route somehow disrupts that process. Or something. The model-building part is very confusing to me.

I ended up subclassing my model from DQNTorchModel but defining the whole thing manually and ignoring q_hiddens. It worked a lot worse than PPO, even with discrete actions. So I probably did something wrong.
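In case it helps anyone later, the general shape of that approach is something like the sketch below (simplified; the obs keys, layer sizes, and the num_atoms == 1 / dueling-off assumptions are illustrative, not my exact model):

import torch
from torch import nn

from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel

FLOAT_MIN = torch.finfo(torch.float32).min


class ManualMaskedDQNModel(DQNTorchModel):
    """Sketch: builds its own torso and Q head (so q_hiddens is effectively
    ignored) and assumes a Dict obs space with "action_mask"/"observations"."""

    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, **kwargs):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kwargs)
        orig_space = getattr(obs_space, "original_space", obs_space)
        obs_size = int(orig_space["observations"].shape[0])  # assumes flat obs
        self.torso = nn.Sequential(
            nn.Linear(obs_size, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU())
        self.q_head = nn.Linear(256, action_space.n)
        self.inf_mask = None

    def forward(self, input_dict, state, seq_lens):
        action_mask = input_dict["obs"]["action_mask"]
        # Stash the log-mask so the Q head can use it.
        self.inf_mask = torch.clamp(torch.log(action_mask), min=FLOAT_MIN)
        return self.torso(input_dict["obs"]["observations"].float()), state

    def get_q_value_distributions(self, model_out):
        # num_atoms == 1 case: mask the per-action Q scores directly.
        action_scores = self.q_head(model_out) + self.inf_mask
        logits = torch.unsqueeze(torch.ones_like(action_scores), -1)
        return action_scores, logits, logits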

@jmugan,

The abstractions abound.

Here is a colab that you can build off of.

P.S. Sorry this is so ugly =|

Wow, very cool! Thanks @mannyv!