[RLlib, Tune, AIR] Checkpointing based on a custom metric minimum

Hi Ray community,

I am successfully using checkpointing with the combination of RLlib, Tune and AIR, as shown in the example below.

It would be useful for me to create a checkpoint not at a fixed interval, like every 10th iteration, but every time I reach a better value for a metric, e.g. “episode_reward_max” or even a custom metric.

In words, the rule would be: “create a checkpoint every time you reach a lower value for episode_reward_max than in all previous iterations”.
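To make the rule above concrete, here is a small framework-free sketch of “checkpoint whenever the metric reaches a new best”. It does not use Ray at all; the function name new_best_iterations and the sample history are hypothetical, chosen only for illustration. It just reports at which iterations such a checkpoint would be triggered.

```python
from typing import Iterable, List


def new_best_iterations(values: Iterable[float], mode: str = "min") -> List[int]:
    """Return the 0-based iterations at which `values` reaches a new best.

    mode="min": a strictly lower value than every earlier one counts as better.
    mode="max": a strictly higher value than every earlier one counts as better.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    best = None
    hits = []
    for i, v in enumerate(values):
        # The first value is always a new best; afterwards require strict improvement.
        improved = best is None or (v < best if mode == "min" else v > best)
        if improved:
            best = v
            hits.append(i)
    return hits


# Hypothetical per-iteration values of a metric we want to minimize.
history = [5.0, 4.2, 4.8, 3.9, 3.9, 4.1, 2.7]
print(new_best_iterations(history, mode="min"))  # [0, 1, 3, 6]
```

Ties (the second 3.9) do not count as an improvement here; whether ties should trigger a checkpoint is a design choice.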

Does anyone have experience with this?

Example that creates a checkpoint at a fixed interval of 10 iterations:

from ray import air, tune

# `config` is the PPO config defined elsewhere
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    tune_config=tune.TuneConfig(num_samples=2),
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(
            checkpoint_score_attribute="episode_reward_mean",
            checkpoint_score_order="max",
            checkpoint_frequency=10,  # checkpoint every 10th training iteration
            checkpoint_at_end=True,
        ),
    ),
)

The following code snippet is what I expected to work. The idea is to always keep the last two checkpoints on disk at which training reached a new minimum of the custom metric “number”. The custom metric is recorded in on_episode_end() of a custom callback. However, I observe that only checkpoint_at_end=True takes effect: a single checkpoint is created when the trial ends.
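For the score attribute to work, its name has to match a key in the (flattened) training result. The sketch below illustrates where a name like “custom_metrics/number_min” plausibly comes from. It is modelled on the older RLlib API stack, assuming the callback stores the value as episode.custom_metrics["number"]; RLlib then reports each custom metric per iteration as number_mean, number_min and number_max, and Tune refers to nested result keys with “/”. The helper names summarize_custom_metrics and flatten are hypothetical, and newer RLlib versions report metrics differently, so check the actual keys in your result dict.

```python
from statistics import mean
from typing import Any, Dict, List


def summarize_custom_metrics(per_episode: List[Dict[str, float]]) -> Dict[str, float]:
    """Aggregate per-episode custom metrics into <key>_mean / _min / _max."""
    keys = sorted({k for ep in per_episode for k in ep})
    out: Dict[str, float] = {}
    for k in keys:
        vals = [ep[k] for ep in per_episode if k in ep]
        out[f"{k}_mean"] = mean(vals)
        out[f"{k}_min"] = min(vals)
        out[f"{k}_max"] = max(vals)
    return out


def flatten(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]:
    """Flatten nested dicts, joining keys with '/'."""
    flat: Dict[str, Any] = {}
    for k, v in d.items():
        key = f"{prefix}/{k}" if prefix else k
        if isinstance(v, dict):
            flat.update(flatten(v, key))
        else:
            flat[key] = v
    return flat


# Three episodes finished in one training iteration (hypothetical values).
episodes = [{"number": 7.0}, {"number": 3.0}, {"number": 5.0}]
result = {"training_iteration": 1, "custom_metrics": summarize_custom_metrics(episodes)}
print(flatten(result)["custom_metrics/number_min"])  # 3.0
```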

Which piece is missing?

from ray import air, tune

# `config` is the PPO config defined elsewhere
tuner = tune.Tuner(
    "PPO",
    param_space=config,
    tune_config=tune.TuneConfig(num_samples=2),
    run_config=air.RunConfig(
        checkpoint_config=air.CheckpointConfig(
            num_to_keep=2,
            checkpoint_score_attribute="custom_metrics/number_min",
            checkpoint_score_order="min",
            checkpoint_at_end=True,
        ),
    ),
)
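One possible explanation, offered as a hypothesis rather than a confirmed diagnosis: in Tune, checkpoint_score_attribute, checkpoint_score_order and num_to_keep decide which checkpoints are kept, but they do not by themselves trigger new checkpoints. Creation is driven by checkpoint_frequency and checkpoint_at_end, so the config above would only ever produce the end-of-trial checkpoint. The toy simulation below (the function simulate_checkpoints is hypothetical and ignores details of Ray's real checkpoint manager) shows this, and shows that checkpointing every iteration with num_to_keep=2 and order "min" leaves exactly the two best-scoring checkpoints. Those are the two most recent new minima, which matches the behavior asked for.

```python
from typing import List, Optional, Tuple


def simulate_checkpoints(
    scores: List[float],
    frequency: int = 0,
    at_end: bool = True,
    num_to_keep: Optional[int] = None,
    order: str = "min",
) -> List[Tuple[int, float]]:
    """Return (iteration, score) pairs of the checkpoints left on disk.

    Checkpoints are *created* only every `frequency` iterations (0 = never)
    and/or after the last iteration. The score order is used afterwards,
    only to decide which of the created checkpoints are kept.
    """
    created = []
    n = len(scores)
    for it in range(1, n + 1):
        periodic = frequency > 0 and it % frequency == 0
        final = at_end and it == n
        if periodic or final:  # one checkpoint even if both conditions hold
            created.append((it, scores[it - 1]))
    # Rank created checkpoints best-first and keep the top `num_to_keep`.
    ranked = sorted(created, key=lambda c: c[1], reverse=(order == "max"))
    kept = ranked if num_to_keep is None else ranked[:num_to_keep]
    return sorted(kept)  # report in iteration order


# Hypothetical per-iteration values of custom_metrics/number_min.
scores = [9.0, 4.0, 6.0, 2.0, 8.0, 5.0]
print(simulate_checkpoints(scores, frequency=0, num_to_keep=2))  # [(6, 5.0)]
print(simulate_checkpoints(scores, frequency=1, num_to_keep=2))  # [(2, 4.0), (4, 2.0)]
```

The trade-off of frequency=1 is that a checkpoint is written to disk every iteration before the worse ones are pruned.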

@justinvyu, @christina: This older thread has become relevant again. I don't expect this to require development effort, just a better understanding of how Tune checkpointing works.

Thanks for your help!

Hi @PhilippWillms! Sorry this hasn't been solved yet. Your expectations align with mine.
Is there a related GitHub issue with a reproduction script? That would be the best way to track and solve this. If not, could you please file one? I'll reproduce it and take further steps.

Done, I created issue #54251 in ray-project/ray on GitHub: “[RLlib,Tune,AIR] Checkpointing scoring per custom metric does not work”.

Thank you for surfacing this issue and creating a ticket!
