How to load a PyTorch .pth checkpoint in Ray

Hi folks! How do I load a PyTorch .pth checkpoint that was created outside of Ray into Ray Train? Ray Trainer checkpoints only seem to accept .pkl files.

Hi @vgill, you don’t need to go through Ray Trainer checkpoints to read that file. You can call torch.load() directly inside your train_loop_per_worker, then save and report a new checkpoint from there, for example:

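import os
import tempfile

import torch
from ray.air import session
from ray.train.torch import TorchCheckpoint

def train_loop_per_worker():
    # Load the external checkpoint created outside of Ray
    state_dict = torch.load("ckpt.pth")
    ...  # build the model, model.load_state_dict(state_dict), train, compute `metrics`
    with tempfile.TemporaryDirectory() as tmpdir:
        torch.save(state_dict, os.path.join(tmpdir, "ckpt_new.pth"))
        # Report metrics and the new checkpoint to the AIR session
        # (`checkpoint` is a keyword-only argument of session.report)
        session.report(metrics, checkpoint=TorchCheckpoint.from_directory(tmpdir))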
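Loading the .pth file itself needs nothing Ray-specific. Below is a minimal sketch, assuming the file holds only a model's state_dict (some checkpoints instead wrap it in a dict with extra keys such as optimizer state, in which case you would index into it first). The helper name `load_external_checkpoint` and the toy `nn.Linear` model are hypothetical stand-ins for your own code.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_external_checkpoint(model: nn.Module, path: str) -> nn.Module:
    """Load a plain PyTorch .pth state_dict (saved outside Ray) into `model`."""
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only worker;
    # move the model to the worker's device afterwards if needed.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


if __name__ == "__main__":
    # Simulate a checkpoint that was created outside of Ray.
    with tempfile.TemporaryDirectory() as tmpdir:
        path = os.path.join(tmpdir, "ckpt.pth")
        original = nn.Linear(4, 2)
        torch.save(original.state_dict(), path)

        # Load it into a freshly initialized model and compare the weights.
        restored = load_external_checkpoint(nn.Linear(4, 2), path)
        same = all(
            torch.equal(a, b)
            for a, b in zip(original.state_dict().values(), restored.state_dict().values())
        )
        print(same)  # prints True
```

Inside train_loop_per_worker you would call such a helper on your model before training starts. Each worker runs the loop and reads the file independently, so the path has to be readable from every node in the cluster (for example, on a shared filesystem).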