I changed how I extract the data from the batch object (thanks to Antoni Yard1 Baum for the help on Slack) and now move the data to the right GPU manually:
BEFORE (not working)
train_loader = train.torch.prepare_data_loader(train_loader)
inputs = batch[0]
masks = batch[1]
AFTER (working)
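# Turn off automatic device placement in the prepared loader
train_loader = train.torch.prepare_data_loader(train_loader, move_to_device=False)
# Device assigned to this training worker
device = train.torch.get_device()
# Inside the training loop: unpack the batch, then move each tensor explicitly
inputs, masks = batch
inputs = inputs.to(device)
masks = masks.to(device)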
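
For anyone who wants to try the unpack-and-move pattern on its own, here is a minimal sketch in plain PyTorch. It is only an illustration, not the actual training script: the helper name `move_batch_to_device` and the toy dataset are hypothetical, and the `device` line is a stand-in for `train.torch.get_device()` so that the snippet runs outside a Ray Train worker.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def move_batch_to_device(batch, device):
    """Unpack an (inputs, masks) batch and move both tensors to `device`."""
    inputs, masks = batch
    return inputs.to(device), masks.to(device)


# Stand-in (assumption) for train.torch.get_device(), so this runs without Ray.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical toy data: 8 samples with 4 features each, plus matching masks.
dataset = TensorDataset(torch.randn(8, 4), torch.ones(8, 4))
train_loader = DataLoader(dataset, batch_size=2)

for batch in train_loader:
    inputs, masks = move_batch_to_device(batch, device)
    # ... forward pass, loss, and backward pass would go here ...
```

In the real setup, the loader would still go through `prepare_data_loader(..., move_to_device=False)` as in the AFTER snippet, and `device` would come from `train.torch.get_device()`.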