[Tune] [RLlib] Save On Best Training Reward Callback

Hello,
I am new to Ray RLlib and would like to move from Stable Baselines to RLlib.
My question: is it possible to create a callback in RLlib that saves the model whenever it reaches a new best result, like it is done here:
SaveOnBestTrainingRewardCallback
https://stable-baselines3.readthedocs.io/en/master/guide/examples.html
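The core logic of such a callback is small: after each training iteration, compare the mean episode reward with the best value seen so far and write a checkpoint only when it improves. Below is a minimal, framework-agnostic sketch of that loop. `FakeTrainer` and `train_and_save_best` are hypothetical names. `FakeTrainer` is a stand-in assumed to mirror the shape of an RLlib trainer in Ray 1.x, where `train()` returns a result dict containing `"episode_reward_mean"` and `save()` writes a checkpoint and returns its path.

```python
from typing import Dict, List, Optional


class FakeTrainer:
    """Hypothetical stand-in for an RLlib trainer (assumption, not RLlib code):
    train() returns a result dict, save() returns a checkpoint path."""

    def __init__(self, rewards: List[float]):
        self._rewards = iter(rewards)
        self.iteration = 0

    def train(self) -> Dict[str, float]:
        self.iteration += 1
        return {"episode_reward_mean": next(self._rewards)}

    def save(self) -> str:
        # A real trainer would write files to disk; this only returns a fake path.
        return f"checkpoint_{self.iteration:06d}"


def train_and_save_best(trainer, num_iterations: int) -> Optional[str]:
    """Train for num_iterations, checkpointing only when the mean reward improves."""
    best_reward = float("-inf")
    best_path = None
    for _ in range(num_iterations):
        result = trainer.train()
        reward = result["episode_reward_mean"]
        # Save only on a new best; ties and regressions are skipped.
        if reward > best_reward:
            best_reward = reward
            best_path = trainer.save()
    return best_path


best = train_and_save_best(FakeTrainer([10.0, 25.0, 18.0, 30.0, 22.0]), num_iterations=5)
print(best)  # checkpoint_000004
```

Because any comparison with NaN is `False` in Python, an iteration whose reward is NaN (for example, before any episode has finished) is simply skipped rather than saved.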

@sven1977 what are the best practices here?

It depends on the problem. You should try it and check whether the save-on-best-reward callback improves your model or not. Anyway, is it possible to do that with RLlib as in Stable Baselines?

Hey @Hannan , great question and sorry for the delay, which was caused by the question being “uncategorized”. It helps if you set a category (e.g. “RLlib”) when you post a new question. That way, we’ll find it more easily and can assign the right person to answer it.

Hey @rliaw , is there such a callback in Tune that gets triggered when a new best-reward trial has been found?

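There is: ray.tune.run takes checkpointing arguments for this (e.g. keep_checkpoints_num together with checkpoint_score_attr). Ray Train exposes the same idea through CheckpointStrategy, for example: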

from ray import train
from ray.train import CheckpointStrategy, Trainer


def train_func():
    # first checkpoint
    train.save_checkpoint(loss=2)
    # second checkpoint
    train.save_checkpoint(loss=4)
    # third checkpoint
    train.save_checkpoint(loss=1)
    # fourth checkpoint
    train.save_checkpoint(loss=3)

# Keep the 2 checkpoints with the smallest "loss" value.
checkpoint_strategy = CheckpointStrategy(num_to_keep=2,
                                         checkpoint_score_attribute="loss",
                                         checkpoint_score_order="min")

trainer = Trainer(backend="torch", num_workers=2)
trainer.start()
trainer.run(train_func, checkpoint_strategy=checkpoint_strategy)
print(trainer.best_checkpoint_path)
# /home/ray_results/train_2021-09-01_12-00-00/run_001/checkpoints/checkpoint_000003
print(trainer.latest_checkpoint_dir)
# /home/ray_results/train_2021-09-01_12-00-00/run_001/checkpoints
print([checkpoint_path for checkpoint_path in trainer.latest_checkpoint_dir.iterdir()])
# [PosixPath('/home/ray_results/train_2021-09-01_12-00-00/run_001/checkpoints/checkpoint_000003'),
# PosixPath('/home/ray_results/train_2021-09-01_12-00-00/run_001/checkpoints/checkpoint_000001')]
trainer.shutdown()
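To make the retention rule in this example explicit, here is a minimal pure-Python sketch of "keep the N checkpoints with the best score". `TopKCheckpointKeeper` and `Checkpoint` are hypothetical helpers written for illustration, not Ray classes; they only model the behavior described by `num_to_keep`, `checkpoint_score_attribute` and `checkpoint_score_order` above.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Checkpoint:
    index: int    # 1-based order in which the checkpoint was saved
    score: float  # value of the score attribute, e.g. "loss"


class TopKCheckpointKeeper:
    """Hypothetical helper: keep the num_to_keep best-scoring checkpoints."""

    def __init__(self, num_to_keep: int, order: str = "min"):
        if order not in ("min", "max"):
            raise ValueError("order must be 'min' or 'max'")
        self.num_to_keep = num_to_keep
        self.order = order
        self.kept: List[Checkpoint] = []
        self._count = 0

    def save(self, score: float) -> None:
        self._count += 1
        self.kept.append(Checkpoint(self._count, score))
        # Sort best-first, then discard everything beyond num_to_keep.
        self.kept.sort(key=lambda c: c.score, reverse=(self.order == "max"))
        del self.kept[self.num_to_keep:]

    @property
    def best(self) -> Checkpoint:
        return self.kept[0]


# Same sequence of "loss" values as in train_func above.
keeper = TopKCheckpointKeeper(num_to_keep=2, order="min")
for loss in [2, 4, 1, 3]:
    keeper.save(loss)

print(keeper.best.index)                       # 3
print(sorted(c.index for c in keeper.kept))    # [1, 3]
```

This matches the printed paths in the example: the best checkpoint is the third one (loss 1), and the two survivors are checkpoints 3 and 1.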