Applying an action mask with DQNTrainer doesn't work when 'hiddens' is a non-empty list

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

I use

  • Python 3.9
  • Ray 1.13
  • Windows 10

I have a custom model class that inherits from TorchModelV2. The code below is not the full code: I have left out the irrelevant parts and kept only the relevant ones.

from typing import Dict, List

import gym
import torch.nn as nn
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.utils.typing import ModelConfigDict, TensorType

class CustomNetwork(TorchModelV2, nn.Module):
    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str):
        # TorchModelV2 is just a container and doesn't really do anything.
        # Therefore, we need nn.Module to have a real neural network.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        # some code ...

        self.model = nn.Sequential(..)

    def forward(self,
                input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> (TensorType, List[TensorType]):
        obs = input_dict["obs"]
        model_input = obs['model_input']
        action_mask = obs['action_mask']
        q_values = self.model(model_input)
        # Set masked-out actions (mask == 0) to the lowest value so they are never chosen.
        q_values.masked_fill_(action_mask == 0, FLOAT_MIN)
        # forward() has to return the outputs together with the state.
        return q_values, state

When this model is then used with DQNTrainer in Tune, a wrapper class, DQNTorchModel, is put around it.

from ray import tune
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_DEFAULT_CONFIG, DQNTrainer
from ray.rllib.models import ModelCatalog

config = DQN_DEFAULT_CONFIG.copy()
# A lot of other parameters ...
# Then the relevant parameters
network_name = 'CustomNetwork'
ModelCatalog.register_custom_model(network_name, CustomNetwork)
config['model'] = {"custom_model": network_name}
config['hiddens'] = [32]
tune.run(DQNTrainer, ...)

This wrapper class gives an error because it adds an extra layer after the model output and runs a forward pass through it. By that point some values in the model output are already set to FLOAT_MIN, so the error makes sense to me.
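
To illustrate the point above, here is a minimal sketch in plain PyTorch, not the actual RLlib code. It assumes FLOAT_MIN is the smallest finite float32 value, and it uses a small hypothetical `extra_layer` with fixed weights in place of the layer that gets built from 'hiddens'. Once the masked entries pass through a linear layer, the result overflows and is no longer finite. This is one plausible source of trouble; the actual error may differ.

```python
import torch
import torch.nn as nn

# Stand-in for ray.rllib.utils.torch_utils.FLOAT_MIN
# (assumption: the smallest finite float32 value).
FLOAT_MIN = torch.finfo(torch.float32).min

# Pretend this is the already-masked output of CustomNetwork.forward():
# actions 1 and 2 are masked out.
masked_model_out = torch.tensor([[0.5, FLOAT_MIN, FLOAT_MIN]])

# Hypothetical stand-in for the extra layer created from config['hiddens'],
# shrunk to 2 units and given fixed weights so the result is reproducible.
extra_layer = nn.Linear(3, 2, bias=False)
with torch.no_grad():
    extra_layer.weight.copy_(torch.tensor([[1.0, 2.0, 2.0],
                                           [1.0, 2.0, 2.0]]))
    hidden = extra_layer(masked_model_out)

# Multiplying FLOAT_MIN by a weight larger than 1 overflows float32.
all_finite = bool(torch.isfinite(hidden).all())
print(hidden, all_finite)
```

The masked values are also no longer tied to individual actions after this layer, since every output unit mixes all input entries.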

# This code is Ray source code and can be found in ray.rllib.agents.dqn.dqn_torch_model.py
class DQNTorchModel(TorchModelV2, nn.Module):

    def __init__(...): 
        ...
   
    def get_q_value_distributions(self, model_out):
        action_scores = self.advantage_module(model_out)
        ....
        return action_scores, logits, logits

How I was planning to fix this: if I subclass DQNTorchModel and override the get_q_value_distributions function, I can apply the action mask after the action_scores are computed and remove the masking from my forward function.
Problem: I don't have access to the action mask in the get_q_value_distributions function.

Question: Given this case, how would I apply an action mask after the action_scores?
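
One possible direction, sketched in plain PyTorch rather than against the real DQNTorchModel: store the mask on the model during forward and read it back when the action scores are computed. The class `MaskCachingModel` and the attribute `_last_action_mask` are hypothetical names, and the sketch assumes get_q_value_distributions is always called right after forward on the same batch. It returns only the action scores, while the real method also returns logits, and it does not consider settings such as dueling.

```python
import torch
import torch.nn as nn

# Stand-in for RLlib's FLOAT_MIN (assumption: smallest finite float32 value).
FLOAT_MIN = torch.finfo(torch.float32).min


class MaskCachingModel(nn.Module):
    """Hypothetical stand-in for a custom model wrapped by a DQN-style head."""

    def __init__(self, obs_dim: int, num_features: int, num_actions: int):
        super().__init__()
        # Plays the role of self.model in CustomNetwork.
        self.model = nn.Linear(obs_dim, num_features)
        # Plays the role of the extra layer built from 'hiddens'.
        self.advantage_module = nn.Linear(num_features, num_actions)
        # Hypothetical attribute, filled in by forward().
        self._last_action_mask = None

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        # Remember the mask, but do NOT apply it to the features.
        self._last_action_mask = obs["action_mask"]
        return self.model(obs["model_input"]), state

    def get_q_value_distributions(self, model_out):
        action_scores = self.advantage_module(model_out)
        # Apply the cached mask to the final per-action scores.
        return action_scores.masked_fill(self._last_action_mask == 0, FLOAT_MIN)


torch.manual_seed(0)
net = MaskCachingModel(obs_dim=4, num_features=8, num_actions=3)
obs = {
    "model_input": torch.randn(2, 4),
    "action_mask": torch.tensor([[1, 0, 1], [0, 1, 1]]),
}
features, _ = net({"obs": obs}, [], None)
scores = net.get_q_value_distributions(features)
best_actions = scores.argmax(dim=1)  # only allowed actions can win
```

Whether the cached mask always matches the batch passed to get_q_value_distributions inside RLlib would need to be checked against the policy code before relying on this.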