ApexTrainer fails to instantiate: "mat1 and mat2 shapes cannot be multiplied"

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi,

I have a custom environment and model which work fine for PPO and IMPALA. I am trying to use the same setup for ApexDQN, but I am running into trouble.

When I instantiate the ApexTrainer class, I get the following error. It seems to be coming from within model.get_q_value_distributions(), which takes the output of model().

  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 613, in __init__
    self._build_policy_map(                                                                                                                                                                               
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1803, in _build_policy_map                                                               
    self.policy_map.create_policy(                                                                                                                                                                        
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/policy/policy_map.py", line 123, in create_policy                               
    self[policy_id] = create_policy_for_framework(                                                                                                                                                        
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/utils/policy.py", line 80, in create_policy_for_framework                                 
    return policy_class(observation_space, action_space, merged_config)                                                                                                                                   
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/policy/policy_template.py", line 330, in __init__                            
    self._initialize_loss_from_dummy_batch(                                                                                                                                                               
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 1053, in _initialize_loss_from_dummy_batch                   
    actions, state_outs, extra_outs = self.compute_actions_from_input_dict(                                                                                                                               
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/policy/torch_policy.py", line 320, in compute_actions_from_input_dict        
    return self._compute_action_helper(                                                                                                                                                                   
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/utils/threading.py", line 24, in wrapper                                                      
    return func(self, *a, **k)                                                                                                                   
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/policy/torch_policy.py", line 953, in _compute_action_helper                                     
    dist_inputs, dist_class, state_out = self.action_distribution_fn(                                                                                                                                     
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py", line 234, in get_distribution_inputs_and_class
    q_vals = compute_q_values(                                                                                                                                                                            
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/algorithms/dqn/dqn_torch_policy.py", line 424, in compute_q_values                      
    (action_scores, logits, probs_or_logits) = model.get_q_value_distributions(                                                                                                                           
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/algorithms/dqn/dqn_torch_model.py", line 146, in get_q_value_distributions                                     
    action_scores = self.advantage_module(model_out)                                                                                                                                                      
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl                                                   
    return forward_call(*input, **kwargs)                                                                  
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward                                                            
    input = module(input)                                                                                                                                                                                 
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl                            
    return forward_call(*input, **kwargs)                                                                                                                                                                  
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/ray/rllib/models/torch/misc.py", line 169, in forward                           
    return self._model(x)                                                                                                                                                                                 
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl                            
    return forward_call(*input, **kwargs)                                                                                                                                                                 
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139, in forward                             
    input = module(input)                                                                                                                                                                                 
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl                            
    return forward_call(*input, **kwargs)                                                                                                                          
  File "/scratch/zciccwf/py36/envs/ddls/lib/python3.9/site-packages/torch/nn/modules/linear.py", line 114, in forward                               
    return F.linear(input, self.weight, self.bias)                                                                                                         
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x17 and 16x256) 

The problem is, I am not sure which model.get_q_value_distributions() method is being called inside compute_q_values. It does not seem to be the custom model I built for PPO and IMPALA and am now passing into the ApexTrainer config: when I print from within my custom model’s get_q_value_distributions(), nothing is printed, so I assume it is never called.

Does anyone know what might be causing this 32x17 vs. 16x256 mismatch? model() outputs 17 dimensions because my environment’s action_space.n is 17, but for some reason model.get_q_value_distributions() expects 16. I am not sure how to change what model.get_q_value_distributions() expects, or which model is actually being used when this method is called.

Thanks in advance for your help!

I think the issue is that the num_outputs being passed into my custom model is wrong: it is 16 rather than 17. Does anyone know where RLlib sets num_outputs? I expected it to be set automatically to action_space.n, but it is not. This does not affect PPO and IMPALA, because my custom model always outputs action_space.n logits regardless, but ApexDQN adds advantage and value modules on top, so I need num_outputs to be set correctly for RLlib.

So I think I have figured one thing out: for policy gradient methods such as PPO and IMPALA, num_outputs = action_space.n, but for DQN, num_outputs is tied to the hiddens setting in the DQN config. The advantage_module and value_module that DQN builds internally take my custom model’s num_outputs-sized output and map it to action_space.n. Therefore, to fix my issue, I have to make my custom model’s output match what these hiddens-based heads expect, rather than outputting action_space.n logits directly.
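To make this concrete, here is a rough, untested sketch of the kind of model I now think DQN expects. The class name, the layer sizes, and the use of obs_flat are placeholders rather than my real model; the point is only that forward() returns num_outputs features and lets the heads that DQNTorchModel builds map them to action_space.n:

import numpy as np
import torch.nn as nn

from ray.rllib.algorithms.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models import ModelCatalog


class MyDQNModel(DQNTorchModel):
    """Outputs `num_outputs` features; DQN's own heads turn them into Q-values."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name, **kwargs):
        # **kwargs forwards DQN-specific options (q_hiddens, dueling, ...) to
        # DQNTorchModel, which builds the advantage_module / value_module.
        super().__init__(obs_space, action_space, num_outputs, model_config, name, **kwargs)
        obs_size = int(np.prod(obs_space.shape))
        # Output exactly num_outputs units (16 in my case), NOT action_space.n,
        # so the first Linear inside advantage_module (num_outputs -> 256 with
        # the default hiddens) lines up.
        self._encoder = nn.Sequential(
            nn.Linear(obs_size, 256),
            nn.ReLU(),
            nn.Linear(256, num_outputs),
        )

    def forward(self, input_dict, state, seq_lens):
        return self._encoder(input_dict["obs_flat"].float()), state


ModelCatalog.register_custom_model("my_dqn_model", MyDQNModel)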

However, this raises another issue. Currently, inside my custom model’s forward(), I am using action masking as follows:

# Additive mask: 0 for valid actions, float32 min for invalid ones.
inf_mask = torch.maximum(
    torch.log(input_dict['obs']['action_mask']).to(device),
    torch.tensor(torch.finfo(torch.float32).min).to(device),
)
logits += inf_mask

However, now that my custom model’s forward() must return an output that feeds the 256-unit hidden layers of DQN’s internal advantage and value modules (rather than action logits), I can no longer do the action masking inside my model’s forward() method.

What is the best way to implement action masking with DQN?
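One idea I am considering (an untested sketch building on the model above; the cached _last_inf_mask attribute and the structure of my observation dict are my own assumptions, not anything RLlib prescribes) is to cache the additive mask during forward() and apply it to the action scores returned by DQNTorchModel.get_q_value_distributions(), i.e. after DQN’s internal advantage head has produced one score per action:

import torch


class MaskedDQNModel(MyDQNModel):
    """Applies the action mask to the Q-values rather than inside forward()."""

    def forward(self, input_dict, state, seq_lens):
        # Additive mask: 0 for valid actions, float32 min for invalid ones.
        mask = input_dict["obs"]["action_mask"].float()
        self._last_inf_mask = torch.clamp(
            torch.log(mask), min=torch.finfo(torch.float32).min
        )
        # Features are still computed exactly as before (num_outputs wide).
        return super().forward(input_dict, state, seq_lens)

    def get_q_value_distributions(self, model_out):
        action_scores, logits, probs_or_logits = super().get_q_value_distributions(model_out)
        # forward() runs before this inside compute_q_values, so the cached mask
        # matches the current batch. Assumes num_atoms=1 (plain Q-value scores).
        return action_scores + self._last_inf_mask, logits, probs_or_logits

If I remember right, the parametric-actions example in the RLlib repo instead sets "hiddens": [] and "dueling": False in the DQN config, so that num_outputs goes back to action_space.n and the masking can stay inside forward(); that would be the simpler route if the extra heads are not needed.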

Hi @cwfparsonson ,

Thanks for reporting this. Such a shape mismatch should not be caused by your environment.
If you are not providing a custom model, please open an issue in our repo.

Cheers

I’m facing this issue, but I’m not using a custom algorithm. Any headway on this?

Hi @Aidan_McLaughlin ,

Please post a repro script.
Thanks.