Hi folks! How do I load a PyTorch .pth checkpoint that was created outside of Ray when training with Ray? Ray trainer checkpoints only seem to accept pkl files.
Hi @vgill, you don't need to use Ray Trainer checkpoints. For example, you can call torch.load() directly in your train_loop_per_worker:
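import tempfile

import torch
from ray.air import session
from ray.train.torch import TorchCheckpoint

def train_loop_per_worker():
    # load external ckpt
    state_dict = torch.load("ckpt.pth")
    ...
    with tempfile.TemporaryDirectory() as tmpdir:
        torch.save(state_dict, f"{tmpdir}/ckpt_new.pth")
        # report ckpt to AIR session; checkpoint is a keyword-only argument,
        # and metrics is the dict produced by the training code elided above
        session.report(metrics, checkpoint=TorchCheckpoint.from_directory(tmpdir))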
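If it helps, here is a minimal sketch of just the loading step, runnable without Ray. The names build_model and load_external_checkpoint are hypothetical helpers, and the nn.Linear model is only a stand-in for whatever architecture produced your .pth file. It assumes the file holds a plain state_dict rather than a whole pickled model.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model() -> nn.Module:
    # Hypothetical architecture; it must match the one that produced the checkpoint.
    return nn.Linear(4, 2)


def load_external_checkpoint(path: str) -> nn.Module:
    model = build_model()
    # map_location="cpu" lets the file load even if it was saved on a GPU
    # that the current worker does not have.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


# Stand-in for a checkpoint that was created outside of Ray.
tmpdir = tempfile.mkdtemp()
ckpt_path = os.path.join(tmpdir, "ckpt.pth")
source_model = build_model()
torch.save(source_model.state_dict(), ckpt_path)

# The same call could be made at the top of train_loop_per_worker.
restored_model = load_external_checkpoint(ckpt_path)
```

One thing to watch for: workers may not share the driver's working directory, so a relative path like "ckpt.pth" might not resolve there. An absolute path on storage that every worker can read is probably the safer choice.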