Passing additional tensors that are not part of model.view_requirements from custom_actions to the loss function

I have a model where masks are generated on the fly in an overridden custom_actions function. The mask cannot be generated in the model, because the mask generation algorithm requires access to data in the info_batch.

The mask then needs to be passed to the loss function, and the only way to do that is to place the mask tensor in the extra_fetches variable in the custom_actions function.
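For readers less familiar with this pattern, here is a minimal, framework-free sketch of the data flow described above, assuming each info dict carries a list of currently valid action indices. The names `compute_mask_from_infos`, `custom_actions(logits, infos)` and the `"valid_actions"` info key are hypothetical stand-ins, not RLlib's API, and greedy selection is used only to keep the example deterministic.

```python
import torch


def compute_mask_from_infos(infos, num_actions):
    # Hypothetical rule: each info dict lists the action indices that are valid right now.
    mask = torch.zeros(len(infos), num_actions, dtype=torch.bool)
    for row, info in enumerate(infos):
        mask[row, info["valid_actions"]] = True
    return mask


def custom_actions(logits, infos):
    # Stand-in for the overridden action function (hypothetical signature).
    mask = compute_mask_from_infos(infos, logits.shape[-1])
    # Invalid actions get -inf so they can never be selected.
    masked_logits = logits.masked_fill(~mask, float("-inf"))
    actions = masked_logits.argmax(dim=-1)  # greedy, for simplicity only
    # The mask travels alongside the actions so the loss can read it later.
    extra_fetches = {"valid_action_mask": mask}
    return actions, extra_fetches


logits = torch.tensor([[0.1, 2.0, 0.5], [3.0, 0.2, 1.0]])
infos = [{"valid_actions": [0, 2]}, {"valid_actions": [1]}]
actions, fetches = custom_actions(logits, infos)
```

The key point is that the mask depends on `infos`, which the model never sees, so it can only be produced at action-computation time and forwarded through the extra fetches.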

However, a problem arises here: during template generation the mask is not included in the view_requirements (and it shouldn't be, because it's not a view requirement of the model).

As a result, template generation fails because the action mask tensor is missing from the train_batch in the loss function.

I’m currently working around this by just adding a default tensor in the loss function:

valid_action_mask = train_batch.get('valid_action_mask', torch.zeros(....))
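A self-contained sketch of the same fallback idea, assuming a plain dict stands in for train_batch and that logits have shape (batch_size, num_actions). `get_action_mask`, `masked_loss` and `MASK_KEY` are hypothetical helpers; unlike the one-liner above, this version defaults to an all-True mask (nothing masked out) and only builds the default when the key is actually absent.

```python
import torch

MASK_KEY = "valid_action_mask"


def get_action_mask(train_batch, batch_size, num_actions):
    # If the mask is absent (e.g. in a template/dummy batch), fall back to an
    # all-True mask so that no action is masked out.
    mask = train_batch.get(MASK_KEY)
    if mask is None:
        mask = torch.ones(batch_size, num_actions, dtype=torch.bool)
    return mask


def masked_loss(train_batch, logits):
    # Hypothetical loss: negative log-likelihood of the taken actions under masked logits.
    mask = get_action_mask(train_batch, logits.shape[0], logits.shape[1])
    # A large negative constant instead of -inf avoids NaNs in log_softmax.
    masked_logits = logits.masked_fill(~mask, -1e9)
    log_probs = torch.log_softmax(masked_logits, dim=-1)
    taken = log_probs.gather(1, train_batch["actions"].unsqueeze(1))
    return -taken.mean()


logits = torch.zeros(2, 3)
dummy_batch = {"actions": torch.tensor([0, 1])}  # no mask key, as during template generation
real_batch = {
    "actions": torch.tensor([0, 1]),
    MASK_KEY: torch.tensor([[True, False, False], [True, True, False]]),
}
```

Note that a fallback like this only triggers while the key is missing; once a batch contains the key (with whatever values), the batch's tensor is used as-is.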

Is there a better, “proper” way of doing this, or is this the expected solution to this problem?


This method works for the first few samples. However, RLlib then seems to magically inject ‘valid_action_mask’ itself, which completely breaks the logic above.

I have no idea why this is happening.