Bug in tuner.restore / optuna_search

Hi!
I am using Ray 2.5.0.

Problem:
Ray Tune with OptunaSearch does not seem to restore errored trials correctly. I think the issue comes up for trials that have errored but have not yet reached max_failures. The result is a KeyError like this:

  File ".../lib/python3.8/site-packages/ray/tune/search/optuna/optuna_search.py", line 533, in on_trial_complete
    ot_trial = self._ot_trials[trial_id]
KeyError: '138dae89'

To Reproduce:

  1. Run the following script like so:
python3 raytest.py my_experiment_folder
  2. Stop the script after there have been failed trials
  3. Run the script like so:
python3 raytest.py my_experiment_folder --restore

I have an experiment that has been running for quite a while and that I currently can’t restore, so I am interested in both a hot fix for restoring it and a way to prevent the issue in the first place.
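
One possible stopgap, sketched here without having been verified against Ray 2.5.0, is to guard the lookup that fails in the traceback so that completion callbacks for trial IDs missing from `_ot_trials` are skipped. The sketch assumes the KeyError arises because the restored searcher has no entry for the errored trial. `FakeOptunaSearch`, `SkipUnknownTrialsMixin` and `TolerantSearch` are hypothetical names; the fake class only mimics the failing line so the example runs without Ray installed.

```python
class FakeOptunaSearch:
    """Hypothetical stand-in for OptunaSearch; mimics only the failing lookup."""

    def __init__(self):
        self._ot_trials = {}  # trial_id -> Optuna trial, as in the traceback
        self.completed = []   # trial IDs that were actually processed

    def on_trial_complete(self, trial_id, result=None, error=False):
        ot_trial = self._ot_trials[trial_id]  # KeyError for unknown trial IDs
        self.completed.append(trial_id)


class SkipUnknownTrialsMixin:
    """Hypothetical mixin: ignore completions for trial IDs the searcher does not know."""

    def on_trial_complete(self, trial_id, result=None, error=False):
        if trial_id not in self._ot_trials:
            # Assumption: the restored searcher state lacks this trial,
            # so there is nothing to report back to Optuna.
            return
        super().on_trial_complete(trial_id, result=result, error=error)


class TolerantSearch(SkipUnknownTrialsMixin, FakeOptunaSearch):
    pass


searcher = TolerantSearch()
searcher._ot_trials["known"] = object()
searcher.on_trial_complete("138dae89", error=True)  # unknown ID: skipped, no KeyError
searcher.on_trial_complete("known", result={"accuracy": 0.99, "macs": 123})
print(searcher.completed)  # prints ['known']
```

With Ray, the same mixin would be placed in front of the real class, e.g. `class TolerantOptunaSearch(SkipUnknownTrialsMixin, OptunaSearch)`, and passed as `search_alg`. The skipped trial is never reported to Optuna, so this hides the symptom rather than fixing the restore, and whether a restored experiment picks up the subclass is untested.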

Here is raytest.py:

import ray
from ray import tune, air
from ray.air import session
from ray.tune.search.optuna import OptunaSearch
from optuna.samplers import MOTPESampler
import argparse


def trainable(config):
    if config['init_learning_rate'] == 0.0001:
        raise Exception  # force a failed trial

    metrics = {
        'accuracy': 0.99,
        'macs': 123
    }
    session.report(metrics)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Test restore bug.')
    parser.add_argument('save_dir', type=str, help="where to save experiment")
    parser.add_argument("--restore", action="store_true", help="use if you want to restore an experiment from save_dir")
    args = parser.parse_args()

    # define search
    sampler = MOTPESampler(n_startup_trials=10)
    search_alg = OptunaSearch(points_to_evaluate=[], sampler=sampler, metric=["accuracy", "macs"],
                              mode=["max", "min"])

    tune_config = tune.TuneConfig(num_samples=1000,
                                  search_alg=search_alg)

    failure_config = air.FailureConfig(
        max_failures=2,
        fail_fast=False
    )
    run_config = air.RunConfig(name="restorebug",
                               storage_path=args.save_dir,
                               verbose=3,
                               failure_config=failure_config)

    config = {
        'init_learning_rate': tune.choice([0.00001, 0.0001, 0.001, 0.01, 0.1])
    }

    tuner = tune.Tuner(trainable, param_space=config, tune_config=tune_config, run_config=run_config)

    if not args.restore:
        results = tuner.fit()
    else:
        # Join with a path separator so this resolves to <save_dir>/restorebug
        tuner = tune.Tuner.restore(f"{args.save_dir}/restorebug",
                                   trainable=trainable,
                                   restart_errored=True, resume_unfinished=True)
        tuner.fit()

I have encountered the same problem. Any solutions?