How severely does this issue affect your experience of using Ray?
- Low: It annoys or frustrates me for a moment.
I am in the following situation: I am training a neural density estimator that has a bunch of hyper-parameters. I would like Ray Tune to stop trials that do not learn in the early epochs and, for trials that do show learning, to wait until convergence (as defined by a ray.tune.stopper.Stopper class). The code below shows the tune.Tuner setup that I use to achieve this goal.
For every trial that has converged, I would like to collect some information. The neural density estimator is used to compute an approximate posterior distribution, and the information I would like to collect is a set of metrics that show whether that posterior is properly calibrated. The question is: how can I use callbacks to collect this information for every converged trial? A rough sketch of what I have in mind follows the Tuner code below.
NB: in the code below, convergence means that ExperimentPlateauStopper.has_plateaued() == True.
from ray import tune
from ray.air import session, RunConfig, Result
from ray.tune.schedulers import ASHAScheduler
from ray.tune.stopper import ExperimentPlateauStopper
def train_fn(config):
    # configure and train the neural density estimator with the
    # hyper-parameters in `config`; once per epoch, report the metric that
    # ASHA and the stopper watch, e.g.
    # session.report({'log_probs': ..., 'epoch': epoch})
    pass
config = {
    # this is where 12 hyper-parameters are specified
}
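# (For illustration only; the real 12 entries are omitted and these names are
# made up.) The values are Ray Tune search-space definitions, e.g.
#     'learning_rate': tune.loguniform(1e-4, 1e-2),
#     'hidden_features': tune.choice([32, 64, 128]),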
gpus_per_trial = 1  # for example; set to the number of GPUs available per trial

tuner = tune.Tuner(
    trainable=tune.with_resources(
        tune.with_parameters(train_fn),
        resources={"cpu": 2.0, "gpu": gpus_per_trial}
    ),
    tune_config=tune.TuneConfig(
        search_alg=None,  # use Tune's default search algorithm
        metric='log_probs',
        mode='max',
        scheduler=ASHAScheduler(
            time_attr='epoch',
            max_t=100,
            grace_period=1,
            reduction_factor=2
        ),
        num_samples=300,
    ),
    run_config=RunConfig(
        local_dir='./ray_logs',
        name='test',
        callbacks=None,  # THIS IS WHAT I THINK I NEED TO COLLECT THE CALIBRATION INFO!
        stop=ExperimentPlateauStopper(
            metric='log_probs',
            std=0.05,
            top=10,
            mode='max',  # match the 'max' mode used for 'log_probs' above
            patience=3,
        ),
        log_to_file=True,
    ),
    param_space=config,
)

result = tuner.fit()
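To make the question concrete, this is a rough sketch of the kind of callback I have in mind. CalibrationCallback and compute_calibration_metrics are names I made up, and I am not sure that on_trial_complete is even the right hook, since it presumably also fires for trials that ASHA stopped early rather than only for trials that plateaued.

from ray import tune

def compute_calibration_metrics(trial_config, last_result):
    # placeholder for my own code: rebuild the approximate posterior for this
    # trial (e.g. from its checkpoint) and compute the calibration metrics
    return {}

class CalibrationCallback(tune.Callback):
    # sketch of a callback that collects calibration metrics per finished trial

    def __init__(self):
        self.calibration_results = {}

    def on_trial_complete(self, iteration, trials, trial, **info):
        # trial.config and trial.last_result are available here; store the
        # metrics keyed by the trial id
        self.calibration_results[trial.trial_id] = compute_calibration_metrics(
            trial.config, trial.last_result
        )

I would then pass callbacks=[CalibrationCallback()] in the RunConfig above and read the collected metrics from the callback instance after tuner.fit(). Is this the intended way to do it, and how would I restrict it to trials that actually converged (plateaued) rather than trials that were stopped early?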