How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
I am using:
- Python 3.9
- Ray 1.13
- Windows 10
I have a custom model class that inherits from TorchModelV2. The snippet below is not the full code; I have abstracted away the irrelevant parts and kept only the relevant ones.
```python
import gym
import torch.nn as nn
from typing import Dict, List, Tuple

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.torch_utils import FLOAT_MIN
from ray.rllib.utils.typing import ModelConfigDict, TensorType


class CustomNetwork(TorchModelV2, nn.Module):
    def __init__(self,
                 obs_space: gym.spaces.Space,
                 action_space: gym.spaces.Space,
                 num_outputs: int,
                 model_config: ModelConfigDict,
                 name: str):
        # TorchModelV2 is just a container and doesn't really do anything.
        # Therefore, we also need nn.Module to have a real neural network.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # some code ...
        self.model = nn.Sequential(..)

    def forward(self,
                input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        obs = input_dict["obs"]
        model_input = obs["model_input"]
        action_mask = obs["action_mask"]
        q_values = self.model(model_input)
        # Push the scores of invalid actions down to FLOAT_MIN.
        q_values.masked_fill_(action_mask == 0, FLOAT_MIN)
        return q_values, state
```
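For completeness: the forward() above assumes the environment exposes a Dict observation space with a "model_input" and an "action_mask" entry. The shapes below are placeholders, not my real ones; only the keys matter:

```python
import gym
import numpy as np

# Placeholder shapes -- my real environment differs, the keys are what matter.
NUM_ACTIONS = 5
observation_space = gym.spaces.Dict({
    "model_input": gym.spaces.Box(-np.inf, np.inf, shape=(10,), dtype=np.float32),
    "action_mask": gym.spaces.Box(0.0, 1.0, shape=(NUM_ACTIONS,), dtype=np.float32),
})
```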
Then, when I run the DQNTrainer with tune, a wrapper class, DQNTorchModel, is put around my custom model. This is the relevant part of my training setup:
```python
from ray import tune
from ray.rllib.agents.dqn.dqn import DEFAULT_CONFIG as DQN_DEFAULT_CONFIG, DQNTrainer
from ray.rllib.models import ModelCatalog

config = DQN_DEFAULT_CONFIG.copy()
# A lot of other parameters ...

# Then the relevant parameters:
network_name = 'CustomNetwork'
ModelCatalog.register_custom_model(network_name, CustomNetwork)
config['model'] = {"custom_model": network_name}
config['hiddens'] = [32]

tune.run(DQNTrainer, ...)
```
This wrapper class gives an error, because it builds an extra layer (from config['hiddens']) on top of the model output and does a forward pass through it. At that point some of the values are already set to FLOAT_MIN, so the breakage makes sense to me.
```python
# This code is Ray source code and can be found in ray/rllib/agents/dqn/dqn_torch_model.py
class DQNTorchModel(TorchModelV2, nn.Module):
    def __init__(...):
        ...

    def get_q_value_distributions(self, model_out):
        action_scores = self.advantage_module(model_out)
        ...
        return action_scores, logits, logits
```
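To make the failure mode concrete, here is a tiny, self-contained toy example (made-up sizes, not my actual network) of what happens once a FLOAT_MIN-masked output is pushed through an extra dense layer like the advantage_module:

```python
import torch
import torch.nn as nn
from ray.rllib.utils.torch_utils import FLOAT_MIN

torch.manual_seed(0)
# Pretend the custom model produced 3 "q-values" and action 1 is masked out.
masked_out = torch.tensor([[1.0, FLOAT_MIN, 2.0]])
# Stand-in for the extra layer(s) that config['hiddens'] = [32] creates.
extra_layer = nn.Sequential(nn.Linear(3, 32), nn.ReLU(), nn.Linear(32, 3))
# Every output is now dominated by the huge FLOAT_MIN value, so the mask no
# longer lines up with individual actions.
print(extra_layer(masked_out))
```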
How I was planning to fix this: if I override DQNTorchModel and its get_q_value_distributions function, I can apply the action mask after the action_scores are computed and remove the masking from my forward function.

Problem: I don't have access to the action mask inside the get_q_value_distributions function.
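For illustration, this is roughly what I had in mind (the class name is made up, and the commented-out line is exactly the part I cannot fill in, because the mask never reaches this method):

```python
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel


class MaskedDQNTorchModel(DQNTorchModel):
    """Sketch of the planned override -- not working code."""

    def get_q_value_distributions(self, model_out):
        action_scores, logits, probs_or_logits = super().get_q_value_distributions(model_out)
        # This is where I would like to do something like
        #   action_scores = action_scores.masked_fill(action_mask == 0, FLOAT_MIN)
        # but there is no action_mask available inside this method.
        return action_scores, logits, probs_or_logits
```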
Question: Given this setup, how can I apply the action mask after the action_scores are computed?