Restore and serve from a remote checkpoint

How severely does this issue affect your experience of using Ray?

  • Medium: It contributes significant difficulty to completing my task, but I can work around it.

I have used Ray Tune to train and sync checkpoints to S3.
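
For reference, the training run synced results to S3 via Tune's sync configuration; this is a minimal sketch of roughly what I ran, with the bucket name and stopping criterion as placeholders:

    from ray import tune

    # Train PPO with Tune and sync results (including checkpoints) to S3.
    # The upload_dir bucket is a placeholder, not my real bucket.
    tune.run(
        "PPO",
        config={
            "env": "CartPole-v0",
            "framework": "torch",
        },
        stop={"training_iteration": 10},
        checkpoint_freq=1,
        checkpoint_at_end=True,
        sync_config=tune.SyncConfig(upload_dir="s3://my-bucket/ray-results"),
    )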

I want to restore a PPO Agent with a checkpoint stored in S3.

The Trainable class in ray/trainable.py at tag ray-1.12.0 of ray-project/ray has the following code:

    def restore(self, checkpoint_path):
        """Restores training state from a given model checkpoint.
        These checkpoints are returned from calls to save().
        Subclasses should override ``load_checkpoint()`` instead to
        restore state.
        This method restores additional metadata saved with the checkpoint.
        `checkpoint_path` should match with the return from ``save()``.
        `checkpoint_path` can be
        `~/ray_results/exp/MyTrainable_abc/
        checkpoint_00000/checkpoint`. Or,
        `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.
        `self.logdir` should generally be corresponding to `checkpoint_path`,
        for example, `~/ray_results/exp/MyTrainable_abc`.
        `self.remote_checkpoint_dir` in this case, is something like,
        `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`
        """
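
For the purely local case described in the docstring, restoring is straightforward; here is a minimal sketch using the docstring's illustrative path (not a real file on my machine):

    import ray.rllib.agents.ppo as ppo

    trainer = ppo.PPOTrainer(
        config={"framework": "torch", "num_workers": 0},
        env="CartPole-v0",
    )
    # Restore from a local checkpoint file returned by save(); the path is
    # the docstring's example, shown here only to illustrate the call.
    trainer.restore("~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint")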

However, I am unsure of the best practice for implementing this when using the following Serve deployment:

    from ray import serve
    from starlette.requests import Request
    import ray.rllib.agents.ppo as ppo

    @serve.deployment(route_prefix="/cartpole-ppo")
    class ServePPOModel:
        def __init__(self, checkpoint_path) -> None:
            self.trainer = ppo.PPOTrainer(
                config={
                    "framework": "torch",
                    "num_workers": 0,
                },
                env="CartPole-v0",
            )
            # Mirror the Trainable attributes mentioned in the docstring
            # above, in the hope that restore() will pull from S3.
            self.uses_cloud_checkpointing = True
            self.remote_checkpoint_dir = checkpoint_path

            self.trainer.restore(checkpoint_path)

        async def __call__(self, request: Request):
            json_input = await request.json()
            obs = json_input["observation"]

            # Compute a single action for the supplied observation.
            action = self.trainer.compute_single_action(obs)
            return {"action": int(action)}
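
For completeness, here is how I start and query the deployment; a minimal sketch, where the S3 URI passed to deploy() is a placeholder and is exactly the part I am unsure about:

    import requests
    import ray
    from ray import serve

    ray.init()
    serve.start()

    # Deploy, passing the (placeholder) S3 checkpoint URI to __init__.
    ServePPOModel.deploy(
        "s3://my-bucket/ray-results/exp/MyTrainable_abc/checkpoint_00000/checkpoint"
    )

    # Query the endpoint with a CartPole observation (4 floats).
    resp = requests.get(
        "http://localhost:8000/cartpole-ppo",
        json={"observation": [0.0, 0.0, 0.0, 0.0]},
    )
    print(resp.json())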

Are there best practices for doing this?

Hi @peterhaddad3121, thank you for your question!

I want to share two recommendations here: