Restoring an RLModule checkpoint with PyTorch

How severe does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Hi, I am migrating a working PPO app (training plus a variety of inference scripts) from the old API to the new API stack. Training now works, and I can store checkpoints with Algorithm.save(). (I could alternatively use RLModule.save_to_path(), but I don’t know what the difference is or which is better; if one offers an advantage, please enlighten me.)
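
For reference, this is roughly what my end-of-training save code looks like. It's only a sketch: CartPole stands in for my real env, and `get_module("default_policy")` is my assumption about how the new API stack hands back the underlying RLModule, so check it against your Ray version's docs.

        from ray.rllib.algorithms.ppo import PPOConfig

        # Minimal train-and-save sketch (CartPole stands in for my real env).
        config = (
            PPOConfig()
            .environment("CartPole-v1")
            .api_stack(
                enable_rl_module_and_learner=True,
                enable_env_runner_and_connector_v2=True,
            )
        )
        algo = config.build()
        algo.train()

        # Option 1: checkpoint the whole Algorithm (learners, env runners, config, ...).
        algo.save("<checkpoint_dir>")

        # Option 2: save only the module's weights via the RLModule itself.
        # get_module("default_policy") is my guess at how the new API stack
        # exposes the module held by the learner; verify for your Ray version.
        module = algo.get_module("default_policy")
        module.save_to_path("<module_dir>")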

My goal is an inference program built solely on PyTorch, with no Ray overhead involved, so it has to rely only on the plain torch methods for finding and restoring the model’s state_dict (e.g. torch.load()). I had a good workflow with checkpoints from the old API, but the checkpoint file structure now appears to be completely different. In particular, I see no .pt file! Can someone please provide some insight into which files are essential for my use case and how to work with them?

Edit: The Ray docs at Checkpointing — Ray 2.42.1 almost get there. They do a good job of describing how to use checkpoints in a Ray environment, and even mention that checkpoints can be deployed to PyTorch-only inference installations, but they never describe how the files can be used directly by PyTorch. Maybe I’m blind and this is so simple it needs no explanation? But I don’t think torch.load() can handle Ray’s .pkl files directly.

Thank you.

Solved! The checkpoint contains several files named module_state.pkl, but the one you need is under the learner and the policy of interest. It holds the state dict and nothing else, so loading it is trivial with plain pickle (the usual caution about unpickling untrusted files still applies, just as it does for a .pt file).

        import pickle

        # The per-policy module state lives under the learner's rl_module directory.
        path = "<checkpoint_path>/learner_group/learner/rl_module/default_policy/module_state.pkl"

        # The file is a plain pickle of the state dict, so pickle.load() is all you need.
        with open(path, "rb") as f:
            sd = pickle.load(f)
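
For anyone else going PyTorch-only: once you have the state dict, this is roughly how I inspect and reuse it. It's a sketch, not a recipe: whether the values arrive as numpy arrays or tensors, and whether the key names line up with your hand-built network, depends on your module, so print the keys first.

        import pickle
        import torch

        with open("<checkpoint_path>/learner_group/learner/rl_module/default_policy/module_state.pkl", "rb") as f:
            sd = pickle.load(f)

        # See what the state dict actually contains before wiring it into a model.
        for name, value in sd.items():
            print(name, getattr(value, "shape", type(value)))

        # The values may be numpy arrays rather than tensors; convert them so
        # torch's load_state_dict() accepts them.
        sd = {k: torch.as_tensor(v) for k, v in sd.items()}

        # Then load into a network that mirrors the trained RLModule's layers
        # (`model` is your own hand-built torch.nn.Module); strict=False reports
        # key-name mismatches instead of raising.
        # missing, unexpected = model.load_state_dict(sd, strict=False)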