cc @sven1977
PyTorch automatic mixed precision (AMP) can provide a 2-3x speedup on modern GPUs. I think this is low-hanging fruit for RLlib.
It is very simple to implement. Instead of

```python
model(input_dict, ...)
```

the forward pass becomes

```python
with torch.cuda.amp.autocast():
    model(input_dict, ...)
```
During the loss update, `GradScaler` scales the loss to prevent fp16 gradient under/overflow:

```python
loss.backward()
optimizer.step()
```

becomes

```python
scaler = torch.cuda.amp.GradScaler()  # created once, outside the training loop
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()  # adjust the scale factor for the next iteration
```
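Putting the two pieces together, here is a minimal, self-contained sketch of an AMP training loop. The model, data, and hyperparameters are placeholders, not RLlib internals; `enabled=use_amp` keeps the script runnable on CPU-only machines, where autocast and scaling become no-ops.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # torch.cuda.amp is only active on CUDA devices

# Placeholder model and optimizer, standing in for an RLlib policy model.
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Create the scaler once, outside the loop; enabled=False makes it a pass-through.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(3):
    inputs = torch.randn(8, 4, device=device)
    targets = torch.randn(8, 2, device=device)

    optimizer.zero_grad()

    # Forward pass under autocast runs eligible ops in float16.
    with torch.cuda.amp.autocast(enabled=use_amp):
        loss = torch.nn.functional.mse_loss(model(inputs), targets)

    # Scale the loss before backward, then step and update the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```

The key details are that the `GradScaler` is constructed once (its scale factor adapts across iterations) and that `scaler.update()` is called after every `scaler.step(optimizer)`.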
Would it be possible to work this into the Trainer config? I see it's already implemented for RaySGD.