Lightning - Early Stopping of training in Tune

I have read this guide. In it, for each hyperparameter combination, Tune seems to compare networks using the metrics obtained by the weights at the end of training. However, I would like to use the weights that yield the lowest validation loss at any point during training.

For example, suppose the grid contains two hyperparameter combinations and each network is trained for 500 iterations, but the first network reaches its lowest validation loss at iteration 70 and the second at iteration 215. I want the grid search to compare the networks at their best points (iterations 70 and 215 respectively) instead of at iteration 500.

For a single network, I know how to do this: use a ModelCheckpoint callback and then read its best_model_path property. However, I don't know how to make Tune do the same. Can anyone help? Thank you!

Hey @Mike1, you can achieve that by configuring checkpointing in Tune to keep the best checkpoint per trial according to a metric. You can do that through the keep_checkpoints_num and checkpoint_score_attr arguments in the tune.run() API, or through the CheckpointConfig object in the new, recommended Tuner API (available from Ray >= 2.0). You can see how to use the Tuner API in the latest version of the documentation, on the "Using PyTorch Lightning with Tune" page (Ray 2.1.0).

Using the example I linked, you’d specify the run_config argument as:

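```python
from ray.air.config import RunConfig, CheckpointConfig

run_config = RunConfig(checkpoint_config=CheckpointConfig(
    checkpoint_score_attribute="loss", checkpoint_score_order="min",  # rank by lowest reported "loss"
    # num_to_keep=1,  # optionally set to only keep the best checkpoint on disk/cloud
))
```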

Then, when you access the checkpoint after the run through the checkpoint attribute (e.g., results.get_best_result().checkpoint), you will receive the checkpoint taken at the iteration that minimized the loss.
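As a rough illustration of the per-trial part, here is a plain-Python sketch of what ranking checkpoints by a score attribute amounts to: every report from one trial carries a loss and a checkpoint, and only the num_to_keep lowest-loss checkpoints are retained. The Report dataclass, the keep_best_checkpoints helper, and the loss values are hypothetical and simplified; this is not Ray's implementation.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass(frozen=True)
class Report:
    """One hypothetical report from a single trial."""
    iteration: int
    loss: float
    checkpoint: str


def keep_best_checkpoints(reports: List[Report], num_to_keep: Optional[int] = None) -> List[Report]:
    """Return the retained reports, lowest loss first.

    Mimics ranking by a "loss" score with order "min": num_to_keep=None keeps
    every checkpoint, otherwise only the num_to_keep lowest-loss ones survive.
    Python's sort is stable, so equal losses keep their original (earlier-first) order.
    """
    ranked = sorted(reports, key=lambda r: r.loss)
    return ranked if num_to_keep is None else ranked[:num_to_keep]


# Hypothetical validation-loss curve for one trial: the minimum is at iteration 3.
history = [
    Report(1, 0.90, "ckpt_1"),
    Report(2, 0.60, "ckpt_2"),
    Report(3, 0.45, "ckpt_3"),
    Report(4, 0.50, "ckpt_4"),
    Report(5, 0.55, "ckpt_5"),
]

best = keep_best_checkpoints(history, num_to_keep=1)
print(best[0].iteration, best[0].checkpoint)  # 3 ckpt_3
```

Note that this only decides which checkpoints a single trial keeps; choosing between trials is a separate step.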


Thank you for the answer! However, I am not sure it does what I meant. It looks like results.get_best_result() still picks the network with the best val loss at the end of training, not the one whose val loss was smallest at any point, and the checkpoint then returns the best val loss point of that network. For example, suppose I have two networks, net1 and net2, and:

loss(net1_at_end_of_training) < loss(net2_at_end_of_training)
loss(net1_at_best_point_during_training) > loss(net2_at_best_point_during_training)

It seems that your code returns net1_at_best_point_during_training, but I want something that returns net2_at_best_point_during_training. Any suggestions?

Got it, thanks for clarifying! In that case, you want to do:
results.get_best_result(scope="all").checkpoint. By default, get_best_result only considers the last reported metric of each trial, but you can change the scope to consider all reports. Then, checkpoint will return the best checkpoint associated with that result.
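The difference between the two scopes can be sketched in plain Python. The loss histories below are made up so that they satisfy the two inequalities from the question, and best_trial is a hypothetical helper that mimics the comparison get_best_result makes for mode="min" with scope="last" versus scope="all"; it is not Ray's actual code.

```python
from typing import Dict, List

# Hypothetical validation-loss histories (one value per reported iteration), chosen so that
#   net1 ends lower than net2 (0.40 < 0.50),
#   but net2's best point is lower than net1's (0.30 < 0.38).
histories: Dict[str, List[float]] = {
    "net1": [0.90, 0.60, 0.45, 0.38, 0.40],
    "net2": [0.80, 0.50, 0.30, 0.45, 0.50],
}


def best_trial(histories: Dict[str, List[float]], scope: str = "last") -> str:
    """Pick the trial with the lowest loss (mode="min").

    scope="last": compare each trial's final reported loss.
    scope="all":  compare each trial's lowest loss over all reports.
    """
    if scope == "last":
        score = {name: losses[-1] for name, losses in histories.items()}
    elif scope == "all":
        score = {name: min(losses) for name, losses in histories.items()}
    else:
        raise ValueError(f"unsupported scope: {scope!r}")
    return min(score, key=lambda name: score[name])


print(best_trial(histories, scope="last"))  # net1
print(best_trial(histories, scope="all"))   # net2
```

With scope="last" the end-of-training losses decide, so net1 wins; with scope="all" each trial is represented by its best point, so net2 wins, which is the behaviour asked for.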
