Ray Train raises TypeError: 'generator' object is not subscriptable

I changed how I extract the data from the batch object (thanks to Antoni Yard1 Baum for the help on Slack) and now move the data to the right GPU manually:

BEFORE (not working)

train_loader = train.torch.prepare_data_loader(train_loader)
for batch in train_loader:
    inputs = batch[0]  # TypeError: 'generator' object is not subscriptable
    masks = batch[1]
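The error message explains why indexing fails: in this setup `batch` arrives as a generator (that is what the traceback reports), and Python generators support iteration but not indexing. Tuple unpacking only needs iteration, so `inputs, masks = batch` works where `batch[0]` does not. A plain-Python illustration with no Ray involved (`make_batch` is a made-up stand-in for the wrapped data loader's output):

```python
def make_batch():
    """Hypothetical stand-in: a batch that arrives as a generator, not a tuple."""
    return (x for x in ["inputs", "masks"])


batch = make_batch()
try:
    batch[0]  # indexing needs __getitem__, which generators do not define
except TypeError as err:
    print(f"TypeError: {err}")

batch = make_batch()
inputs, masks = batch  # unpacking only iterates, so it works on a generator
print(inputs, masks)
```

Running this prints the same TypeError message as above, followed by `inputs masks`.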

AFTER (working)

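from ray import train
import ray.train.torch  # makes train.torch available

# Turn off Ray's automatic device transfer and do it by hand instead
train_loader = train.torch.prepare_data_loader(train_loader, move_to_device=False)
device = train.torch.get_device()

for batch in train_loader:
    inputs, masks = batch  # tuple unpacking instead of indexing
    inputs = inputs.to(device)
    masks = masks.to(device)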
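If a batch holds more than two elements, the manual transfer can be folded into a small helper. This is a sketch under stated assumptions, not a Ray API: `move_batch_to_device` is a hypothetical name, and it assumes each batch is a tuple or list whose tensor elements have a `.to(device)` method, as PyTorch tensors do. Elements without `.to()` are passed through unchanged.

```python
from typing import Any, Sequence, Tuple


def move_batch_to_device(batch: Sequence[Any], device: Any) -> Tuple[Any, ...]:
    """Hypothetical helper: move every element of `batch` to `device`.

    Returns a real tuple (not a generator), so the result supports both
    indexing (result[0]) and unpacking (inputs, masks = result).
    Elements without a `.to()` method are returned unchanged.
    """
    return tuple(x.to(device) if hasattr(x, "to") else x for x in batch)
```

In the training loop, `inputs, masks = move_batch_to_device(batch, device)` would replace `inputs, masks = batch` and the two `.to(device)` lines. Returning a tuple rather than a generator is deliberate, so downstream code can still index into the batch.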