Catching trial termination when using schedulers

Hello,

I am tuning hyperparameters using ray.tune with the AsyncHyperBandScheduler.
Each trial runs a training loop and ends with an evaluation on the test set.
When the scheduler terminates a trial early, the trial is stopped and the test evaluation is skipped.

Is there a way to catch the trial termination and force the evaluation on the test set?

Here is a minimal example where I try to catch the termination using try-except, which does not work. If the termination were caught properly, this code would print “>>> test results computed before termination.” in the console.

How could I properly catch trial termination, so I can run the test() function, even when the trial is terminated?

Thank you for your help!

import time

from ray import tune
from ray.tune.schedulers import AsyncHyperBandScheduler

MAX_T = 50


def experiment(config, *, checkpoint_dir=None, **kwargs):
    alpha = config['alpha']
    try:
        train(alpha, MAX_T)
        test(alpha, MAX_T)
    except Exception as ex:
        # force the run of the test function,
        # even when the trial is terminated by the scheduler
        test(alpha, MAX_T, interrupted=True)
        raise ex


def test(alpha, max_t, interrupted=False):
    test_accuracy = alpha * max_t
    test_results = {'alpha': alpha, 'accuracy': test_accuracy}
    if interrupted:
        print(">>> test results computed before termination.")


def train(alpha, max_t):
    for t in range(max_t):
        accuracy = t * alpha
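        # report the metric; each report counts as one training_iteration for the scheduler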
        tune.report(accuracy=accuracy)
        time.sleep(1e-4)


scheduler = AsyncHyperBandScheduler(
    time_attr='training_iteration',
    max_t=MAX_T + 1,
    grace_period=1)

runner = tune.run(experiment,
                  mode='max',
                  metric='accuracy',
                  config={"alpha": tune.uniform(0, 1)},
                  scheduler=scheduler,
                  resources_per_trial={'gpu': 0, 'cpu': 1},
                  num_samples=100,
                  fail_fast=True,
                  verbose=0,
                  )

print("Best hyperparameters found were: \n", runner.best_config)

Hi @vlievin,

Each training_iteration is logically associated with a single call to tune.report. In the example code, it looks like tune.report will be called MAX_T times before getting to test.

One way you might reframe your code is to have train and test each handle a single iteration t (up to MAX_T), and define your experiment as:

for t in range(MAX_T):
    train(alpha, t)
    test(alpha, t)
    tune.report(...)
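
Fleshed out a little with the function API, that pattern could look like the sketch below (train_one_step and test_one_step are hypothetical per-iteration helpers, not existing Tune functions):

from ray import tune

def experiment(config):
    alpha = config["alpha"]
    for t in range(MAX_T):
        accuracy = train_one_step(alpha, t)       # hypothetical: run one training iteration
        test_accuracy = test_one_step(alpha, t)   # hypothetical: evaluate right after it
        # each report counts as one training_iteration; the scheduler may stop
        # the trial here, but this iteration's test metric was already computed
        tune.report(accuracy=accuracy, test_accuracy=test_accuracy)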

This would interleave training, testing, and reporting (which is what can trigger trial termination). However, I’m not sure if the approach I shared matches the logic that you’re expecting. Could you explain a little more about your use case?

Hi @matthewdeng,

Thank you for your reply,

I don’t think this approach would work in my case. My case looks more like this:

for t in range(MAX_T):
    train(alpha, t)
    validation(alpha, t)
    # the next line potentially triggers trial termination, and skips the test phase
    tune.report(...)

# testing - must run even if the trial is stopped before MAX_T
load_best_checkpoint()
test()

I can’t find out how ray.tune interrupts the training loop, so I cannot properly catch the interruption and cannot make sure that testing is performed for each and every model (even when it is interrupted).

More about my use case: I am using Ray Tune coupled with PyTorch Lightning, and AsyncHyperBandScheduler to accelerate the search. During the search, some models score well on the validation set but are killed before they can be evaluated on the test set. Because of limited disk space, I cannot afford to save all models and test them afterwards.

Overall, I think that adding an easy way to catch trial termination would be a good addition to ray.tune. I could help with the implementation if given pointers on where to look.
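
For reference, here is roughly the kind of hook I have in mind, sketched with the class-based Trainable API. This is only a sketch: it assumes that cleanup() is called when the scheduler stops a trial (which I have not verified covers every termination path), and the test logic is just a placeholder.

from ray import tune


class Experiment(tune.Trainable):
    def setup(self, config):
        self.alpha = config["alpha"]
        self.t = 0
        self.tested = False

    def step(self):
        # one training iteration, returned to Tune as a result dict
        accuracy = self.t * self.alpha
        self.t += 1
        return {"accuracy": accuracy}

    def cleanup(self):
        # called when the actor is stopped; run the test phase here so it
        # happens whether the trial finished naturally or was stopped early
        if not self.tested:
            test_accuracy = self.alpha * self.t
            print(f">>> test results computed before termination: {test_accuracy:.3f}")
            self.tested = True

This could then be launched with tune.run(Experiment, ...) using the same scheduler and search space as above.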