Ray restore checkpoint in rllib

Ray saves a bunch of checkpoints during calls to agent.train(). How do I know which checkpoint holds the best agent to load?
Is there a function like get_best_checkpoint(path, mode="max") on Tune's analysis output that lets me compare the saved checkpoints and pick one to load?

Yep, you can use

from ray import tune

analysis = tune.Analysis(experiment_path)  # can also be the result of `tune.run()`

trial_logdir = analysis.get_best_logdir(metric="metric", mode="max")  # Can also just specify trial dir directly

checkpoints = analysis.get_trial_checkpoints_paths(trial_logdir)  # Returns a list of (checkpoint_path, metric) pairs
best_checkpoint = analysis.get_best_checkpoint(trial_logdir, metric="metric", mode="max")

See the Analysis (tune.analysis) documentation for Ray v2.0.0.dev0.

It seems this is only available through ray.tune. My question is whether similar APIs exist when I use RLlib directly.

Would it be a viable alternative for you to use

from ray import tune

analysis = tune.run(
    agent,  # the RLlib trainer class, or its registered name such as "PPO"
    # ...
)

instead of repeated calls to agent.train()?

It probably could be, if I had known that in advance.
I have now trained an agent for several days :slight_smile: and am looking for a way to extract the best saved checkpoint.

Hey @oroojlooy, that’s a good reason to do this without Tune (having already used trainer.train() for 7 days :smiley:). trainer.train() should still log its results to ~/ray_results/[your algo name] … In that directory there should be a progress.csv file and a result.json file (the latter holds the return values of the trainer.train() calls, one per line). You would then have to go through these line by line to find the best iteration, and restore your agent from the corresponding checkpoint via:

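from ray.rllib.agents.ppo import PPOTrainer  # example only: import the trainer class you actually trained with
trainer = PPOTrainer(config=my_config)  # use the same config as during training
trainer.restore(checkpoint_path)  # path to the checkpoint file of the chosen iteration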
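A minimal sketch of that line-by-line search over result.json, assuming each line is one JSON-encoded result dict that contains `training_iteration` and the metric of interest (`episode_reward_mean` here). The function name `best_iteration` is hypothetical, not part of RLlib. Note that it only finds the best iteration; you can restore that exact iteration only if a checkpoint was saved for it.

```python
import json


def best_iteration(result_json_path, metric="episode_reward_mean", mode="max"):
    """Return (training_iteration, metric_value) of the best row in result.json.

    Assumes one JSON-encoded result dict per line, one line per trainer.train() call.
    """
    best = None
    with open(result_json_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines, e.g. a trailing newline
            result = json.loads(line)
            value = result.get(metric)
            if value is None or value != value:  # skip missing or NaN values
                continue
            better = value > best[1] if best and mode == "max" else (best is None or value < best[1])
            if better:
                best = (result["training_iteration"], value)
    return best
```

Early iterations can report NaN for `episode_reward_mean` (no finished episodes yet), which is why those rows are skipped rather than compared.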

Yeah, that is indeed a good reason to use Tune!
The issue with this approach is that I have several training runs for different environments and algorithms, which makes it a bit hard to do by hand. I’ll probably take a middle path: write a piece of code that reads that CSV file, finds the best checkpoint, and then restores from it :smiley:
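A sketch of that middle path, assuming each run directory contains a progress.csv with `training_iteration` and `episode_reward_mean` columns, plus checkpoint folders named `checkpoint_<iteration>` (zero-padded in some Ray versions). The helper names are hypothetical, not an RLlib API. The search is restricted to iterations that actually have a checkpoint, because the best row in the CSV may fall between saves.

```python
import csv
import os
import re

# Assumption: checkpoint folders are named checkpoint_<iteration>, e.g. checkpoint_000100.
CHECKPOINT_DIR = re.compile(r"^checkpoint_(\d+)$")


def checkpointed_iterations(run_dir):
    """Map training iteration -> checkpoint folder for every checkpoint in run_dir."""
    found = {}
    for name in os.listdir(run_dir):
        match = CHECKPOINT_DIR.match(name)
        if match and os.path.isdir(os.path.join(run_dir, name)):
            found[int(match.group(1))] = os.path.join(run_dir, name)
    return found


def best_checkpoint(run_dir, metric="episode_reward_mean", mode="max"):
    """Return (checkpoint_folder, metric_value) for the best checkpointed iteration."""
    checkpoints = checkpointed_iterations(run_dir)
    best = None
    with open(os.path.join(run_dir, "progress.csv"), newline="") as f:
        for row in csv.DictReader(f):
            raw = row.get(metric)
            if not raw:
                continue  # column missing or empty for this iteration
            iteration = int(row["training_iteration"])
            value = float(raw)
            if iteration not in checkpoints or value != value:  # no checkpoint, or NaN
                continue
            if best is None or (value > best[1] if mode == "max" else value < best[1]):
                best = (checkpoints[iteration], value)
    return best


def best_checkpoints(run_dirs, metric="episode_reward_mean", mode="max"):
    """Apply best_checkpoint to several run directories, e.g. one per env/algorithm."""
    return {run_dir: best_checkpoint(run_dir, metric, mode) for run_dir in run_dirs}
```

Pass absolute run directory paths (os.listdir does not expand `~`). Depending on the Ray version, trainer.restore() may expect the checkpoint file inside the returned folder rather than the folder itself, so check what your checkpoint folders contain before restoring.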

Anyway, thanks for the answer. At least I know there is no out-of-the-box solution for this.