How severely does this issue affect your experience of using Ray?
- Medium: It causes significant difficulty in completing my task, but I can work around it.
I have used Ray Tune to train and sync checkpoints to S3, and I now want to restore a PPO agent from one of those checkpoints.
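For reference, the training run was set up roughly like this (the bucket name, stopping criterion, and checkpoint frequency are placeholders rather than my exact values):

```python
from ray import tune

tune.run(
    "PPO",
    config={"env": "CartPole-v0", "framework": "torch"},
    stop={"training_iteration": 50},  # placeholder stopping criterion
    checkpoint_freq=5,                # write a checkpoint every 5 iterations
    # Tune syncs each trial's checkpoints under this S3 prefix.
    sync_config=tune.SyncConfig(upload_dir="s3://my-bucket/ray_results"),
)
```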
The `Trainable` class in `python/ray/tune/trainable.py` (at the `ray-1.12.0` tag of ray-project/ray on GitHub) has the following code:
```python
def restore(self, checkpoint_path):
    """Restores training state from a given model checkpoint.

    These checkpoints are returned from calls to ``save()``.

    Subclasses should override ``load_checkpoint()`` instead to
    restore state. This method restores additional metadata saved
    with the checkpoint.

    `checkpoint_path` should match the return value of ``save()``.
    `checkpoint_path` can be
    `~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint`
    or `~/ray_results/exp/MyTrainable_abc/checkpoint_00000`.

    `self.logdir` should generally correspond to `checkpoint_path`,
    for example, `~/ray_results/exp/MyTrainable_abc`.

    `self.remote_checkpoint_dir` in this case is something like
    `REMOTE_CHECKPOINT_BUCKET/exp/MyTrainable_abc`.
    """
```
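Restoring from a local checkpoint works fine for me; this minimal sketch (using the docstring's example layout) is what I do today:

```python
import os

import ray.rllib.agents.ppo as ppo

trainer = ppo.PPOTrainer(
    config={"framework": "torch", "num_workers": 0},
    env="CartPole-v0",
)
# Per the docstring, either the checkpoint file or its enclosing
# checkpoint_ directory is accepted.
trainer.restore(
    os.path.expanduser("~/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint")
)
```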
However, I am unsure of the best way to implement this in the following Serve deployment:
```python
from ray import serve
from starlette.requests import Request

import ray.rllib.agents.ppo as ppo


@serve.deployment(route_prefix="/cartpole-ppo")
class ServePPOModel:
    def __init__(self, checkpoint_path) -> None:
        self.trainer = ppo.PPOTrainer(
            config={
                "framework": "torch",
                "num_workers": 0,
            },
            env="CartPole-v0",
        )
        # These two attributes are documented on Trainable, but here they
        # end up on the deployment instance rather than on the trainer
        # itself -- which is part of my confusion.
        self.uses_cloud_checkpointing = True
        self.remote_checkpoint_dir = checkpoint_path
        self.trainer.restore(checkpoint_path)

    async def __call__(self, request: Request):
        json_input = await request.json()
        obs = json_input["observation"]
        action = self.trainer.compute_single_action(obs)
        return {"action": int(action)}
```
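I deploy and query it like this (the S3 checkpoint path is a placeholder; passing it straight to `restore()` is exactly the part I am unsure about):

```python
import requests
from ray import serve

serve.start()
# Pass the (placeholder) S3 checkpoint path as the deployment's init arg.
ServePPOModel.deploy(
    "s3://my-bucket/ray_results/exp/MyTrainable_abc/checkpoint_00000/checkpoint"
)

# Query the HTTP endpoint exposed at route_prefix.
print(requests.post(
    "http://localhost:8000/cartpole-ppo",
    json={"observation": [0.0, 0.0, 0.0, 0.0]},
).json())
```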
Are there best practices for doing this?
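The only workaround I can think of is to download the checkpoint directory from S3 myself and hand the local path to `restore()`, along these lines (bucket, prefix, and local paths are placeholders):

```python
import os

import boto3


def fetch_checkpoint(bucket: str, prefix: str, local_dir: str) -> str:
    """Download every object under one checkpoint_ prefix into local_dir."""
    s3 = boto3.client("s3")
    paginator = s3.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        for obj in page.get("Contents", []):
            local_path = os.path.join(local_dir, os.path.relpath(obj["Key"], prefix))
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            s3.download_file(bucket, obj["Key"], local_path)
    return local_dir


# In ServePPOModel.__init__ I would then do something like:
#   local = fetch_checkpoint(
#       "my-bucket", "ray_results/exp/MyTrainable_abc/checkpoint_00000", "/tmp/ckpt")
#   self.trainer.restore(local)
```

But this sidesteps `self.uses_cloud_checkpointing` and `self.remote_checkpoint_dir` entirely, so I suspect it is not the intended pattern.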