Use scheduler and search_alg with different metrics

1. Severity of the issue: (select one)
[*] High: Completely blocks me.

2. Environment:

  • Ray version: 2.47.1
  • Python version: 3.11.11
  • OS: Linux
  • Cloud/Infrastructure: Databricks jobs
  • Other libs/tools (if relevant): optuna

3. What happened vs. what you expected:

  • Expected: ASHAScheduler and OptunaSearch use two different metrics: val_loss for the scheduler to eliminate the worst trials during training, and score_oot for the searcher to optimize hyperparameters

  • Actual: tune.report() cannot serve both components at once; each one validates every reported result against its own metric (details below)

Details:

  • I’m using ReportCheckpointCallback to report val_loss during training; it calls tune.report() internally at the end of each epoch. I get an error that “score_oot” does not exist in the reported metrics. This happens because OptunaSearch also inspects the metrics reported after each epoch and raises when its metric is missing.

  • Setting TUNE_DISABLE_STRICT_METRIC_CHECKING=1 makes every trial terminate after one epoch: OptunaSearch still looks for score_oot in each result and, not finding it, force-terminates the trial.
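The mismatch can be illustrated with plain dictionaries (key names taken from the code snippet; the dict shapes are a simplified sketch, not Ray's actual result schema). During training, each per-epoch report carries only val_loss, and only the final report carries score_oot, so any component that strictly validates every result for score_oot fails on the per-epoch reports:

```python
# Hypothetical shapes of the results each component sees.
epoch_report = {"val_loss": 0.42, "training_iteration": 7}  # from ReportCheckpointCallback
final_report = {"score_oot": 0.81}                          # from tune.report() at trial end

def has_metric(report, metric):
    # Strict metric checking, roughly: every result must contain the metric.
    return metric in report

assert has_metric(epoch_report, "val_loss")       # ASHAScheduler is satisfied every epoch
assert not has_metric(epoch_report, "score_oot")  # OptunaSearch's check fails here
assert has_metric(final_report, "score_oot")      # only the last report has the search metric
```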

Code snippet:

# Imports assumed by the snippet (Ray 2.x; the Keras callback path may differ by version)
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.search.optuna import OptunaSearch
from ray.air.integrations.keras import ReportCheckpointCallback
from tensorflow.keras.callbacks import EarlyStopping

def train_model_objective(config, counter):
    X_train, X_val, X_oot = ray.get(X_train_ref), ray.get(X_val_ref), ray.get(X_oot_ref)
    y_train, y_val = ray.get(y_train_ref), ray.get(y_val_ref)

    model = create_model(config)
    callbacks = [
        EarlyStopping(monitor='val_auc', mode='max', patience=ES_PATIENCE, restore_best_weights=True),
        ReportCheckpointCallback(metrics={'val_loss':'val_loss'}, report_metrics_on='epoch_end'),
    ]
    history = model.fit(
        X_train, y_train, validation_data=(X_val, y_val),
        epochs=EPOCHS, batch_size=config['batch_size'], verbose=0,
        callbacks=callbacks, shuffle=True,
    )
    y_pred_oot = model.predict(X_oot, batch_size=4096)
    score_oot = custom_score(y_pred_oot, decisions_oot)
    tune.report({'score_oot': score_oot})

optuna_search = OptunaSearch(metric="score_oot", mode="max", study_name=f"optuna_study_{run_date}")
scheduler = ASHAScheduler(
    time_attr="training_iteration", metric="val_loss", mode="min", 
    grace_period=20, reduction_factor=1.2
)
tune_config = tune.TuneConfig(
    search_alg=optuna_search, 
    scheduler=scheduler, 
    num_samples=NUM_TRIALS, 
    max_concurrent_trials=MAX_CONC_TRIALS,
    reuse_actors=True,
)
tuner = tune.Tuner(
    train_model_objective,
    tune_config=tune_config,
    param_space=search_space,
)
results = tuner.fit()
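If recomputing the OOT score every epoch is affordable, one direction (a sketch, not a verified fix) is to report both metrics in every result, so ASHAScheduler reads val_loss and OptunaSearch reads score_oot from the same report and neither sees a missing key. DualMetricReporter and its report_fn/score_fn hooks below are hypothetical stand-ins for a Keras Callback that calls tune.report():

```python
class DualMetricReporter:
    """Sketch (names hypothetical): a per-epoch hook that merges the
    scheduler's metric and the search metric into one report, so every
    result contains both keys and strict metric checking passes."""

    def __init__(self, report_fn, score_fn):
        self.report_fn = report_fn  # in a real trial: ray.tune.report
        self.score_fn = score_fn    # e.g. score the current model on X_oot

    def on_epoch_end(self, epoch, logs):
        # Merge the Keras-style epoch logs with the out-of-time score.
        self.report_fn({"val_loss": logs["val_loss"],
                        "score_oot": self.score_fn()})


# Stub usage: capture what would be sent to Tune.
reports = []
cb = DualMetricReporter(report_fn=reports.append, score_fn=lambda: 0.8)
cb.on_epoch_end(0, {"val_loss": 0.5})
assert reports == [{"val_loss": 0.5, "score_oot": 0.8}]
```

The trade-off is paying for an OOT prediction pass each epoch; if that is too expensive, reporting the last known score_oot (updated only when recomputed) keeps the key present in every result without scoring every epoch.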