I have a model where masks are generated on the fly in an overridden custom_action function. The mask cannot be generated in the model itself, because the mask generation algorithm needs access to data in the info_batch.
The mask then needs to be passed to the loss function, and the only way to do that is to place the mask tensor in the extra_fetches variable in the custom_actions function.
A problem arises here, however: during template generation the mask is not included in the view_requirements (and shouldn't be, because it's not a view requirement of the model).
As a result, template generation fails because the action mask tensor is missing from the train_batch in the loss function.
I’m currently working around this by just adding a default tensor in the loss function:
valid_action_mask = train_batch.get('valid_action_mask', torch.zeros(....))
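To make the workaround concrete, here is a minimal sketch of the fallback as a small helper. It is only an illustration: the helper name get_valid_action_mask, the fill_value parameter, and the use of the logits shape for the default are my own assumptions, and a plain dict stands in for the train_batch.

```python
import torch

MASK_KEY = "valid_action_mask"


def get_valid_action_mask(train_batch, logits, fill_value=1.0):
    """Return the mask stored in the batch, or a default shaped like `logits`.

    Hypothetical helper: `fill_value` is an assumption (1.0 here means
    "every action is valid"); pick whatever is neutral for your loss.
    """
    mask = train_batch.get(MASK_KEY)
    if mask is None:
        # The key is absent, e.g. when the loss is called on a batch that
        # was built without the extra fetches.
        mask = torch.full_like(logits, fill_value)
    return mask


# A plain dict stands in for the train_batch in this sketch.
logits = torch.randn(4, 3)
template_batch = {}
real_batch = {MASK_KEY: torch.tensor([[1.0, 0.0, 1.0]] * 4)}

default_mask = get_valid_action_mask(template_batch, logits)
stored_mask = get_valid_action_mask(real_batch, logits)
```

Deriving the default from the logits keeps its shape, dtype, and device consistent with the real mask, so the rest of the loss does not need a special case for the missing key.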
Is there a better “proper” way of doing this or is this an expected solution to this problem?
Thanks