[RLlib, Tune, AIR] Checkpointing based on the minimum of a custom metric

Hi Ray community,

I am successfully using RLlib, Tune and AIR together for checkpointing, as shown in the example below.

Instead of creating a checkpoint at a fixed interval, like every 10th iteration, I would like to create one every time a metric reaches a better value, e.g. “episode_reward_max” or even a custom metric.

In other words: “create a checkpoint every time episode_reward_max reaches a lower value than in all previous iterations”.
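The rule "checkpoint whenever the metric beats every earlier value" boils down to tracking a running best. A minimal, framework-independent sketch of that trigger; `NewBestTrigger` is a hypothetical helper for illustration, not part of Ray:

```python
import math


class NewBestTrigger:
    """Hypothetical helper: signals when a metric beats every earlier value.

    mode="min" fires on a new minimum, mode="max" on a new maximum.
    """

    def __init__(self, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        # Start from the worst possible value so the first observation always counts.
        self.best = math.inf if mode == "min" else -math.inf

    def update(self, value: float) -> bool:
        """Return True if `value` is strictly better than all previous values."""
        improved = value < self.best if self.mode == "min" else value > self.best
        if improved:
            self.best = value
        return improved


if __name__ == "__main__":
    trigger = NewBestTrigger(mode="min")
    history = [5.0, 3.0, 4.0, 2.5, 2.5, 1.0]
    # 1-based iterations at which a checkpoint would be created under this rule.
    checkpoint_iters = [i for i, v in enumerate(history, start=1) if trigger.update(v)]
    print(checkpoint_iters)  # [1, 2, 4, 6]
```

Equal values (iteration 5) do not count as an improvement here; whether ties should trigger a checkpoint is a design choice.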

Does anyone have experience with that?

Example that creates a checkpoint at a fixed interval of 10 iterations:

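```python
from ray import air, tune

# `config` is the PPO config dict (environment, callbacks, ...) defined elsewhere.
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    tune_config=tune.TuneConfig(num_samples=2),
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(
            checkpoint_score_attribute="episode_reward_mean",
            checkpoint_score_order="max",
            checkpoint_frequency=10,  # save every 10th training iteration
            checkpoint_at_end=True,   # plus one checkpoint when the trial ends
        ),
    ),
)
```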
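In this configuration the score attribute only ranks checkpoints; when they are written depends solely on the iteration counter. A minimal model of that schedule, assuming a checkpoint is taken whenever `training_iteration % checkpoint_frequency == 0` and that a frequency of 0 disables periodic checkpoints (the helper `checkpointed_iterations` is hypothetical, not a Ray function):

```python
from __future__ import annotations


def checkpointed_iterations(
    num_iterations: int, checkpoint_frequency: int, checkpoint_at_end: bool
) -> list[int]:
    """Hypothetical model of a fixed-interval checkpoint schedule (1-based iterations).

    Assumes a checkpoint whenever iteration % checkpoint_frequency == 0, with a
    frequency of 0 disabling periodic checkpoints altogether.
    """
    iters: list[int] = []
    if checkpoint_frequency > 0:
        iters = [i for i in range(1, num_iterations + 1) if i % checkpoint_frequency == 0]
    # The end-of-trial checkpoint is only added if the last iteration was not saved already.
    if checkpoint_at_end and num_iterations > 0 and (not iters or iters[-1] != num_iterations):
        iters.append(num_iterations)
    return iters


print(checkpointed_iterations(35, checkpoint_frequency=10, checkpoint_at_end=True))  # [10, 20, 30, 35]
print(checkpointed_iterations(35, checkpoint_frequency=0, checkpoint_at_end=True))   # [35]
```

The second call mirrors a frequency of 0: in this model only the final checkpoint remains.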

The following code snippet is what I hoped would do this. The idea is to always keep on disk the last two checkpoints at which training reached a new minimum of the custom metric "number", which is recorded in on_episode_end() of a custom callback. However, only checkpoint_at_end=True takes effect: a single checkpoint is created when the trial ends.

Which piece is missing?

```python
from ray import air, tune

# `config` is the PPO config dict; its callback records episode.custom_metrics["number"].
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    tune_config=tune.TuneConfig(num_samples=2),
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="custom_metrics/number_min",
            checkpoint_score_order="min",
            checkpoint_at_end=True,
        ),
    ),
)
```
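One hypothesis worth checking: if I read the `air.CheckpointConfig` docs correctly, `checkpoint_score_attribute` and `checkpoint_score_order` only rank checkpoints that already exist, while `checkpoint_frequency` defaults to 0, which disables periodic checkpointing. That would explain why only the `checkpoint_at_end` checkpoint appears. Under this reading, adding `checkpoint_frequency=1` to the config above would save every iteration and let `num_to_keep=2` prune all but the two lowest `custom_metrics/number_min` values. A toy model of that interplay; `flatten` and `retained_checkpoints` are hypothetical stand-ins (not Ray code), and the end-of-trial checkpoint is not modelled:

```python
from __future__ import annotations

from typing import Any


def flatten(result: dict[str, Any], prefix: str = "") -> dict[str, Any]:
    """Flatten nested result dicts into 'outer/inner' keys, the way Tune names nested metrics."""
    flat: dict[str, Any] = {}
    for key, value in result.items():
        name = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, name))
        else:
            flat[name] = value
    return flat


def retained_checkpoints(
    results: list[dict[str, Any]],
    score_attribute: str,
    score_order: str = "min",
    checkpoint_frequency: int = 0,
    num_to_keep: int = 2,
) -> list[int]:
    """Hypothetical model of score-based retention; returns the kept 1-based iterations.

    Checkpoints are only *created* every `checkpoint_frequency` iterations
    (0 = never); the score attribute merely ranks the ones that exist.
    """
    if checkpoint_frequency <= 0:
        return []  # nothing is saved periodically, so there is nothing to rank
    created = [i for i in range(1, len(results) + 1) if i % checkpoint_frequency == 0]
    sign = 1 if score_order == "min" else -1

    def score(i: int) -> float:
        return sign * flatten(results[i - 1])[score_attribute]

    # sorted() is stable, so in this model the earlier iteration wins on equal scores.
    return sorted(sorted(created, key=score)[:num_to_keep])


history = [{"custom_metrics": {"number_min": v}} for v in [5.0, 3.0, 4.0, 2.5, 1.0]]
print(retained_checkpoints(history, "custom_metrics/number_min"))  # []
print(retained_checkpoints(history, "custom_metrics/number_min", checkpoint_frequency=1))  # [4, 5]
```

In this model, keeping the N best checkpoints is close to, but not identical to, "the last N new minima": with scores 5, 1, 2 it keeps iterations 2 and 3, although only iterations 1 and 2 set a new minimum.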