Callback after a trial has converged

How severely does this issue affect your experience of using Ray?

  • Low: It annoys or frustrates me for a moment.

I am in the following situation: I am training a neural density estimator that has a number of hyperparameters. I would like Ray Tune to stop trials that do not learn in the early epochs and, for trials that do show learning, to wait until convergence (as defined by a ray.tune.stopper.Stopper class). The code below shows the tune.Tuner setup I use to achieve this.

For every trial that has converged, I would like to collect some information. The neural density estimator is used to compute an approximate posterior distribution, and the information I would like to collect is a set of metrics that indicate whether the posterior is properly calibrated. The question is: how can I use callbacks to collect this information for every converged trial?

NB In the code below, convergence means that ExperimentPlateauStopper.has_plateaued() == True.

from ray import tune
from ray.air import session, RunConfig
from ray.tune.schedulers import ASHAScheduler
from ray.tune.stopper import ExperimentPlateauStopper

gpus_per_trial = 0  # set to an integer or fraction depending on available GPUs


def train_fn(config):
    # This is where the neural density estimator is configured and trained;
    # it reports 'epoch' and 'log_probs' via session.report().
    pass


config = {
    # This is where the 12 hyperparameters are specified.
}


tuner = tune.Tuner(
    trainable=tune.with_resources(
        tune.with_parameters(train_fn),
        resources={"cpu": 2.0, "gpu": gpus_per_trial},
    ),
    tune_config=tune.TuneConfig(
        search_alg=None,  # defaults to random search
        metric='log_probs',
        mode='max',
        scheduler=ASHAScheduler(
            time_attr='epoch',
            max_t=100,
            grace_period=1,
            reduction_factor=2,
        ),
        num_samples=300,
    ),
    run_config=RunConfig(
        local_dir='./ray_logs',
        name='test',
        callbacks=None,  # THIS IS WHAT I THINK I NEED TO COLLECT THE CALIBRATION INFO!
        stop=ExperimentPlateauStopper(
            metric='log_probs',
            std=0.05,
            top=10,
            mode='max',  # 'log_probs' is maximized, so the top trials are the highest ones
            patience=3,
        ),
        log_to_file=True,
    ),
    param_space=config,
)
result = tuner.fit()

cc @rliaw Should this be moved to the AIR category?

@Patrickens Have you checked out the docs and API?
https://docs.ray.io/en/latest/tune/tutorials/tune-metrics.html
https://docs.ray.io/en/latest/tune/api_docs/internals.html#tune-callbacks-docs


Oops, my bad! I hadn’t found the second link (#tune-callbacks-docs) yet! That looks like exactly what I need. Thank you for thinking along, and sorry for my sloppiness.

No worries. Glad it works!

Turns out I am not quite sure what to do yet. The code below correctly waits until a trial has converged and then stops it. For every converged trial (as opposed to one that is stopped right after the grace period of TrialPlateauStopper or after max_t in the ASHAScheduler), I would like to compute a number of extra metrics. As far as I understand, once a trial meets a stopping criterion, any code after the reporting loop in the train function (the print statement in michment_trainable) will not be executed, so the extra metrics have to be computed via callbacks. How can I create a callback that has access to the model created in the trainable function? Is the only way to do this via checkpointing? How would I create a checkpoint for a converged trial?

import os
import shutil
from pathlib import Path

from ray import tune
from ray.air import session, RunConfig
from ray.air.checkpoint import Checkpoint
from ray.tune import Callback, ResultGrid
from ray.tune.schedulers import ASHAScheduler
from ray.tune.stopper import TrialPlateauStopper


def michment_trainable(config):
    # Toy trainable: 'log_probs' saturates as the epoch count grows (Michaelis-Menten style).
    for i in range(1000):
        log_probs = i / (i + config['Km'])
        session.report({
            "epoch": i,
            "log_probs": log_probs,
        })
    print("this won't get printed after a trial is TERMINATED")  # stand-in for the metrics I would like to compute
    # THIS IS WHERE I WOULD LIKE TO EVALUATE A BUNCH OF METRICS ONCE THE TRIAL HAS CONVERGED!!


def main() -> ResultGrid:
    local_dir = os.path.join('ray_logs_test')
    tune_id = 'test'
    mode = 'max'
    metric = 'log_probs'

    # Clean up results of a previous run with the same name, if any.
    dirpath = Path(local_dir) / tune_id
    if dirpath.exists() and dirpath.is_dir():
        shutil.rmtree(dirpath)

    param_space = {'Km': tune.uniform(1.0, 10.0)}

    tuner = tune.Tuner(
        trainable=tune.with_resources(
            tune.with_parameters(michment_trainable),
            resources={"cpu": 1, "gpu": 0, "memory": 1e9}
        ),
        tune_config=tune.TuneConfig(
            search_alg=None,  # NB defaults to random search; consider BOHB or BayesOptSearch
            metric=metric,
            mode=mode,
            scheduler=ASHAScheduler(
                time_attr='epoch',
                max_t=50,
                grace_period=1,
                reduction_factor=2
            ),
            num_samples=200,
        ),
        run_config=RunConfig(
            local_dir=local_dir,
            name=tune_id,
            callbacks=None,
            stop=TrialPlateauStopper(
                metric=metric,
                std=0.1,
                num_results=10,
                mode='max',
                grace_period=10,
            ),
            log_to_file=True,
        ),
        param_space=param_space,
    )
    result = tuner.fit()
    return result


result = main()

How can I create a callback that has access to the model created in the trainable function?

Hey @Patrickens, does this sort of approach work for you? It should let you access the model after each trial completes.

class CustomCallback(Callback):
    def on_trial_complete(self, iteration, trials, trial, **info):
        # Load the checkpoint that the trainable reported last.
        checkpoint = trial.checkpoint.to_air_checkpoint()
        model = checkpoint.to_dict()["model"]
        ...  # Collect metrics about the model here.


def michment_trainable(config):
    for i in range(1000):
        log_probs = i / (i + config["Km"])
        checkpoint = Checkpoint.from_dict({"model": ...})
        session.report(
            {
                "epoch": i,
                "log_probs": log_probs,
            },
            checkpoint=checkpoint,
        )
    ...

def main() -> Result:
    ...

    tuner = tune.Tuner(
        ...
        run_config=RunConfig(
            ...            
            callbacks=[CustomCallback()],
            ...
        ),
        ...
    )
    result = tuner.fit()
    return result


result = main()

Is the only way to do this via checkpointing?

I think so. Even if there are other ways, I figure checkpointing is the cleanest approach.

How would I create a checkpoint for a converged trial?

on_trial_complete is called for every completed trial, including the ones that stopped early. As far as I know, there’s no way to know whether a trial stopped because of convergence or because of early stopping. cc @kai @justinvyu, do you know of any way to achieve this?
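
In the meantime, here is a workaround sketch (my own assumption, not a documented way to read the stop reason from Tune): a Callback also receives every intermediate result via on_trial_result, so it can buffer the last few metric values per trial and re-check the plateau condition itself in on_trial_complete, skipping the expensive work for trials that never plateaued. The class name and the defaults below are hypothetical and simply mirror the TrialPlateauStopper settings from your snippet:

from collections import defaultdict, deque

import numpy as np
from ray.tune import Callback


class ConvergedOnlyCallback(Callback):
    # Sketch only: keep metric/std/num_results in sync with TrialPlateauStopper.
    def __init__(self, metric="log_probs", std=0.1, num_results=10):
        self._metric = metric
        self._std = std
        self._buffer = defaultdict(lambda: deque(maxlen=num_results))

    def on_trial_result(self, iteration, trials, trial, result, **info):
        # Keep the last `num_results` metric values per trial.
        self._buffer[trial.trial_id].append(result[self._metric])

    def on_trial_complete(self, iteration, trials, trial, **info):
        window = self._buffer.pop(trial.trial_id, None)
        # Treat the trial as converged only if its recent metric values plateaued,
        # i.e. roughly the same condition TrialPlateauStopper uses to stop it.
        if not window or len(window) < window.maxlen or np.std(window) > self._std:
            return  # stopped by ASHA or hit max_t without plateauing: skip
        checkpoint = trial.checkpoint.to_air_checkpoint()
        model = checkpoint.to_dict()["model"]
        ...  # run the expensive calibration metrics only for converged trials

You would register it like any other callback, e.g. callbacks=[ConvergedOnlyCallback()] in the RunConfig.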


For every converged trial (not one that stops after the grace-period in TrialPlateauStopper or after max_t in ASHAScheduler), …

Actually @Patrickens, could you tell me more about what you mean by this? Like, are you referring to trials that reach 1000 epochs? Or trials that plateau? Or something else?

I mean that I don’t want to run the callback for trials that were terminated after satisfying only the grace period, or that ran to 1000 epochs without converging, because the callback is a very expensive operation.

Basically, I am training a model that generates posterior distributions. In the callback, I want to assess the quality of these distributions using some calibration metrics (read about the metrics here: Simulation-based calibration - sbi).

Running the calibration metrics takes several minutes, so I would like to run them as rarely as possible.
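
If modifying the trainable is an option, another possibility (again a sketch based on my assumptions, not something confirmed in this thread) is to let the trainable detect the plateau itself and report a converged flag alongside the metrics. The callback can then read trial.last_result and only run the expensive calibration when that flag is set. The function and class names below are hypothetical:

from collections import deque

import numpy as np
from ray import tune
from ray.air import session
from ray.air.checkpoint import Checkpoint


def michment_trainable(config):
    # Track the last 10 values so the trainable can flag a plateau itself;
    # window size and std threshold mirror the TrialPlateauStopper settings above.
    window = deque(maxlen=10)
    for i in range(1000):
        log_probs = i / (i + config["Km"])
        window.append(log_probs)
        converged = len(window) == window.maxlen and np.std(window) < 0.1
        session.report(
            {"epoch": i, "log_probs": log_probs, "converged": converged},
            # Checkpointing every epoch mirrors the earlier snippet; in practice
            # you would probably checkpoint less often.
            checkpoint=Checkpoint.from_dict({"model": ...}),
        )


class CalibrationCallback(tune.Callback):
    def on_trial_complete(self, iteration, trials, trial, **info):
        # trial.last_result holds the metrics from the final session.report call.
        if not (trial.last_result or {}).get("converged", False):
            return  # stopped after only the grace period or at max_t: skip
        model = trial.checkpoint.to_air_checkpoint().to_dict()["model"]
        ...  # run the (expensive) calibration metrics here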