Trial with unexpected good status encountered: PENDING

Hello there,
I received the error "ray.tune.error.TuneError: Trial with unexpected good status encountered: PENDING" after my BOHB script had been running fine for more than an hour. Below is a reduced version of the code showing the main parts of the project:

imports ...


class TrainableNN(Trainable):
    def setup(self, config, args, ds, ds_info):
        self.timestep = 0
        self.config = config
        self.args = args

        self.train_ds, self.val_ds = split_datasets(args, ds)
        
        torch.manual_seed(args.random_seed)
        if args.model.name == 'resnet18_deterministic':
            self.model = deterministic.resnet.ResNet18(ds_info['n_classes'])
            self.criterion = nn.CrossEntropyLoss()
            self.optimizer = torch.optim.SGD(
                self.model.parameters(),
                lr=config['learning_rate'],
                weight_decay=config['weight_decay'],
                momentum=args.model.optimizer.momentum,
                nesterov=True,
            )
            self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=args.model.n_epochs)
        self.device = args.device

        self.train_loader = torch.utils.data.DataLoader(self.train_ds, batch_size=config['batch_size'], shuffle=True, drop_last=True)
        self.val_loader = torch.utils.data.DataLoader(self.val_ds, batch_size=64, shuffle=False, drop_last=False)


    def step(self):
        for i in range(self.args.step_size):
            _ = self.train_one_epoch(self.train_loader)
            self.lr_scheduler.step()
        self.timestep += 1
        val_stats = self.evaluate(self.val_loader)
        return val_stats
    

    def train_one_epoch(self, dataloader, epoch=None, print_freq=200):
        # Just trains one epoch in a classical PyTorch manner
        # (body omitted for brevity; see the sketch below the full script)
        ...

    @torch.no_grad()
    def evaluate(self, dataloader, dataloaders_ood=None):
        # Just evaluates the current model's performance
        # (body omitted for brevity; see the sketch below the full script)
        ...
    
    @torch.no_grad()
    def collect_predictions(self, dataloader):
        all_logits = []
        all_targets = []
        for inputs, targets in dataloader:
            inputs = inputs.to(self.device)
            all_logits.append(self.model(inputs).cpu())
            all_targets.append(targets)
        logits = torch.cat(all_logits, dim=0)
        targets = torch.cat(all_targets, dim=0)
        return logits, targets


    def save_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save(self.model.state_dict(), checkpoint_path)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_path):
        self.model.load_state_dict(torch.load(checkpoint_path))

    @classmethod
    def default_resource_request(cls, config):
        return PlacementGroupFactory([{"CPU": 4, "GPU": 0.5}])


@hydra.main(version_base=None, config_path="./configs", config_name="hparam_search")
def main(args):
    logger = logging.getLogger()
    logger.info('Using setup: %s', args)

    n_iterations = args.model.n_epochs//args.step_size

    # Init Ray; if we are running under SLURM, set the CPU and GPU counts explicitly
    address = 'auto' if args.distributed else None
    num_cpus = int(os.environ.get('SLURM_CPUS_PER_TASK', args.cpus_per_trial))
    num_gpus = torch.cuda.device_count()
    ray.init(address=address, num_cpus=num_cpus, num_gpus=num_gpus)

    search_space, points_to_evaluate = build_search_space(args)
    bohb_search = TuneBOHB(
        points_to_evaluate=points_to_evaluate,
        metric="test_acc1", 
        mode="max"
    )
    bohb_search = tune.search.ConcurrencyLimiter(bohb_search, max_concurrent=args.max_concurrent)
    bohb_hyperband = HyperBandForBOHB(
        time_attr="training_iteration",
        max_t=n_iterations,
        reduction_factor=args.reduction_factor,
        stop_last_trials=False,
    )

    # Build the dataset once so tune.with_parameters can pass it via the Ray object store
    ds, ds_info = datasets.cifar.build_cifar10('train', args.dataset_path, return_info=True)

    tuner = tune.Tuner(
        tune.with_parameters(TrainableNN, args=args, ds=ds, ds_info=ds_info),
        run_config=air.RunConfig(
            stop={
                "training_iteration": n_iterations,
            },
        ),
        tune_config=tune.TuneConfig(
            search_alg=bohb_search,
            scheduler=bohb_hyperband,
            num_samples=args.n_configs,
            metric="test_acc1", 
            mode="max"
        ),
        param_space=search_space
        )

    results = tuner.fit()    

    print('Best NLL Stats: {}'.format(results.get_best_result().metrics))
    print('Best NLL Hyperparameter: {}'.format(results.get_best_result().config))
    print('Best Acc Hyperparameter: {}'.format(results.get_best_result(metric="test_acc1", mode="max").config))


def build_search_space(args):
    points_to_evaluate = None
    if args.model.name == 'resnet18_deterministic':
        search_space = {
            "learning_rate": tune.uniform(0, .1),
            "weight_decay": tune.uniform(0, .1),
            "batch_size": tune.choice([32, 64, 128])
        }
        points_to_evaluate = [
            {"learning_rate": 1e-1, "weight_decay": 5e-4, "batch_size":64},
            {"learning_rate": 1e-2, "weight_decay": 5e-4, "batch_size":64}
        ]
    else:
        raise NotImplementedError('Model {} not implemented.'.format(args.model.name))
    return search_space, points_to_evaluate


if __name__ == '__main__':
    main()
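
For completeness, the two methods I stripped down in TrainableNN above look roughly like this (a simplified sketch of my code, not the exact implementation; my real evaluate also handles the optional OOD dataloaders):

    def train_one_epoch(self, dataloader, epoch=None, print_freq=200):
        # Standard PyTorch training loop over one epoch
        self.model.to(self.device)
        self.model.train()
        running_loss, n_batches = 0.0, 0
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            self.optimizer.zero_grad()
            loss = self.criterion(self.model(inputs), targets)
            loss.backward()
            self.optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        return {'train_loss': running_loss / max(n_batches, 1)}

    @torch.no_grad()
    def evaluate(self, dataloader, dataloaders_ood=None):
        # Compute loss and top-1 accuracy on the validation split
        self.model.eval()
        logits, targets = self.collect_predictions(dataloader)
        loss = self.criterion(logits, targets).item()
        acc1 = (logits.argmax(dim=1) == targets).float().mean().item() * 100
        return {'test_loss': loss, 'test_acc1': acc1, 'test_nll': loss}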

Here is the resulting stack trace:

Traceback (most recent call last):
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/tuner.py", line 367, in fit
    return self._local_tuner.fit()
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/impl/tuner_internal.py", line 503, in fit
    analysis = self._fit_internal(trainable, param_space)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/impl/tuner_internal.py", line 621, in _fit_internal
    analysis = run(
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/tune.py", line 904, in run
    runner.step()
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 1342, in step
    self._wait_and_handle_event(next_trial)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 1411, in _wait_and_handle_event
    raise TuneError(traceback.format_exc())
ray.tune.error.TuneError: Traceback (most recent call last):
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 1400, in _wait_and_handle_event
    self._on_training_result(
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 694, in _on_training_result
    self._process_trial_results(trial, result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 707, in _process_trial_results
    decision = self._process_trial_result(trial, result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 750, in _process_trial_result
    decision = self._scheduler_alg.on_trial_result(
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/schedulers/hb_bohb.py", line 114, in on_trial_result
    action = self._process_bracket(trial_runner, bracket)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/schedulers/hyperband.py", line 270, in _process_bracket
    raise TuneError(
ray.tune.error.TuneError: Trial with unexpected good status encountered: PENDING

I was running this code on a server with 4 GPUs and 40 CPUs, aiming for 8 parallel trials plus some spare CPU capacity if needed. Is there anything I am missing, or is this a bug on Ray's side?
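
For reference, this is the rough resource arithmetic I had in mind (plain Python, just mirroring the default_resource_request above and the node size, not anything Ray computes for you):

# Each trial requests 4 CPUs and 0.5 GPUs via PlacementGroupFactory.
cpus_per_trial, gpus_per_trial = 4, 0.5
node_cpus, node_gpus = 40, 4

max_by_gpu = int(node_gpus / gpus_per_trial)     # 8
max_by_cpu = node_cpus // cpus_per_trial         # 10
max_concurrent = min(max_by_gpu, max_by_cpu)     # 8 trials -> 32/40 CPUs, 4/4 GPUs in use
print(max_concurrent)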

Additional Note:
Here is the last status update I received before the error occurred:

== Status ==
Current time: 2023-05-04 13:12:05 (running for 02:48:57.77)
Using HyperBand: num_stopped=26 total_brackets=4
Round #0:
  Bracket(Max Size (n)=1, Milestone (r)=12, completed=56.2%): {RUNNING: 1, TERMINATED: 15}
  Bracket(Max Size (n)=2, Milestone (r)=12, completed=49.0%): {RUNNING: 2, TERMINATED: 8}
  Bracket(Max Size (n)=4, Milestone (r)=10, completed=37.1%): {RUNNING: 4, TERMINATED: 3}
  Bracket(Max Size (n)=5, Milestone (r)=10, completed=61.3%): {PAUSED: 4, RUNNING: 1}
Logical resource usage: 32.0/40 CPUs, 4.0/4 GPUs (0.0/1.0 accelerator_type:V100)
Current best trial: 1e952b2e with test_acc1=92.30999755859375 and parameters={'learning_rate': 0.007364477289920681, 'weight_decay': 0.00026112245478438644, 'batch_size': 64}
Result logdir: /mnt/stud/home/phahn/ray_results/TrainableNN_2023-05-04_10-23-06
Number of trials: 38/100 (4 PAUSED, 8 RUNNING, 26 TERMINATED)
+----------------------+------------+---------------------+--------------+-----------------+----------------+--------+------------------+-------------+-------------+------------+
| Trial name           | status     | loc                 |   batch_size |   learning_rate |   weight_decay |   iter |   total time (s) |   test_loss |   test_acc1 |   test_nll |
|----------------------+------------+---------------------+--------------+-----------------+----------------+--------+------------------+-------------+-------------+------------|
| TrainableNN_271901b2 | RUNNING    | 141.51.131.93:24944 |           64 |     0.01        |    0.0005      |     13 |         7455.04  |    0.311992 |       91.35 |   0.311992 |
| TrainableNN_27b322fa | RUNNING    | 141.51.131.93:39254 |          128 |     0.0079678   |    0.00207703  |     11 |         3904.6   |    0.350977 |       89.23 |   0.350977 |
| TrainableNN_33aa0129 | RUNNING    | 141.51.131.93:17092 |           64 |     0.0116685   |    0.00279521  |      5 |         1793.52  |    0.507012 |       82.63 |   0.507012 |
| TrainableNN_51b3eb29 | RUNNING    | 141.51.131.93:30516 |           64 |     0.000517159 |    0.0012137   |     10 |         5734     |    0.393945 |       89.68 |   0.393945 |
| TrainableNN_53b90488 | RUNNING    | 141.51.131.93:28080 |           64 |     0.00197385  |    0.00236213  |      5 |         1795.03  |    0.328827 |       89.65 |   0.328827 |
| TrainableNN_7376e736 | RUNNING    | 141.51.131.93:37382 |           64 |     0.0342845   |    0.00132928  |      8 |         5478.45  |    0.65188  |       79.13 |   0.65188  |
| TrainableNN_7aef5d6d | RUNNING    | 141.51.131.93:42069 |           64 |     0.0095709   |    0.00216552  |      6 |         2146.83  |    0.451911 |       84.87 |   0.451911 |
| TrainableNN_fff989ea | RUNNING    | 141.51.131.93:32824 |          128 |     0.0269434   |    0.0249693   |      9 |         3156.96  |    1.49307  |       49.31 |   1.49307  |
| TrainableNN_039f9d2e | PAUSED     | 141.51.131.93:34033 |           32 |     0.0103221   |    0.0833994   |     10 |         3221.64  |    2.30294  |        9.6  |   2.30294  |
| TrainableNN_1e952b2e | PAUSED     | 141.51.131.93:29038 |           64 |     0.00736448  |    0.000261122 |     10 |         3735.09  |    0.327314 |       92.31 |   0.327314 |
| TrainableNN_ac0224fa | PAUSED     | 141.51.131.93:31947 |           32 |     0.000788363 |    0.0581641   |     10 |         3755.57  |    0.683219 |       80.67 |   0.683219 |
| TrainableNN_e41809c7 | PAUSED     | 141.51.131.93:32497 |           32 |     0.0901537   |    0.00247578  |     10 |         3213.99  |    1.39476  |       54.68 |   1.39476  |
| TrainableNN_17b308f3 | TERMINATED | 141.51.131.93:15765 |          128 |     0.0550693   |    0.0771374   |      2 |         1196.65  |    2.30409  |        9.86 |   2.30409  |
| TrainableNN_24cc2be6 | TERMINATED | 141.51.131.93:9284  |           64 |     0.0554148   |    0.0373097   |      1 |          353.842 |    2.07382  |       20.12 |   2.07382  |
| TrainableNN_24ec6a67 | TERMINATED | 141.51.131.93:17330 |           64 |     0.00109714  |    0.0757342   |      4 |         1467.55  |    0.843984 |       77.18 |   0.843984 |
| TrainableNN_25200616 | TERMINATED | 141.51.131.93:27100 |           64 |     0.065716    |    0.0965496   |      5 |         1805.84  |    2.30316  |        9.94 |   2.30316  |
| TrainableNN_34b86efe | TERMINATED | 141.51.131.93:27872 |           64 |     0.080713    |    0.0332946   |      5 |         3375.9   |    2.30384  |        9.94 |   2.30384  |
| TrainableNN_38d09b83 | TERMINATED | 141.51.131.93:11218 |           64 |     0.00253526  |    0.0677027   |      2 |         1409.73  |    1.01034  |       70.44 |   1.01034  |
| TrainableNN_43b48ec7 | TERMINATED | 141.51.131.93:17548 |           64 |     0.1         |    0.0005      |      4 |         2525.76  |    0.607998 |       79.21 |   0.607998 |
| TrainableNN_48df4487 | TERMINATED | 141.51.131.93:10732 |           32 |     0.0898256   |    0.0337691   |      1 |          343.068 |    2.3049   |        9.31 |   2.3049   |
+----------------------+------------+---------------------+--------------+-----------------+----------------+--------+------------------+-------------+-------------+------------+
... 18 more trials not shown (18 TERMINATED)

Hi @nhaH-luaP, are you running on Ray nightly? What’s your Ray version?

Hey @justinvyu,
my Ray version is 2.4.0, and as far as I'm aware I am not using any nightly Ray wheels. To install, I used the commands:

pip install -U "ray[default]"
pip install -U "ray[tune]"

and

pip install -U "ray[air]"

Got it, this may be a bug with BOHB. I will try to look into this shortly and report back. cc: @kai


@nhaH-luaP Did you enable the new Tune execution backend with the "TUNE_NEW_EXECUTION" environment variable? I believe this issue is related to a bug fixed here: [air/execution] Fix new execution backend for BOHB by krfricke · Pull Request #34828 · ray-project/ray · GitHub. Could you try running with Ray nightly to see if the issue still exists?

Basically, trials paused by BOHB were getting started (set to PENDING) prematurely, which could lead to an issue when the HyperBand scheduler processes the bracket.
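
To opt in, it should be enough to set the environment variable before Ray/Tune starts, for example at the top of your main:

import os

# Opt into the new Tune execution backend before ray.init() / tuner.fit() runs.
os.environ["TUNE_NEW_EXECUTION"] = "1"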

@justinvyu Thanks for the suggestion.

I did as you suggested: installed the nightlies (Linux x86_64, Python 3.9) and set the environment variable at the start of my main. After running for approximately 30 minutes, a new kind of error occurred, see below:

== Status ==
Current time: 2023-05-11 11:14:53 (running for 00:32:16.09)
Using HyperBand: num_stopped=12 total_brackets=3
Round #0:
  Bracket(Max Size (n)=4, Milestone (r)=4, completed=30.0%): {PENDING: 4, TERMINATED: 12}
  Bracket(Max Size (n)=10, Milestone (r)=2, completed=14.6%): {PAUSED: 6, RUNNING: 4}
  Bracket(Max Size (n)=7, Milestone (r)=5, completed=0.0%): {RUNNING: 4}
Logical resource usage: 32.0/40 CPUs, 4.0/4 GPUs (0.0/1.0 accelerator_type:V100)
Current best trial: 116393b0 with test_acc1=87.61000061035156 and parameters={'learning_rate': 0.01, 'weight_decay': 0.0005, 'batch_size': 64}
Result logdir: /mnt/stud/home/phahn/ray_results/TrainableNN_2023-05-11_10-42-37
Number of trials: 30/100 (6 PAUSED, 4 PENDING, 8 RUNNING, 12 TERMINATED)
+----------------------+------------+-----------------------+--------------+-----------------+----------------+--------+------------------+-------------+-------------+------------+
| Trial name           | status     | loc                   |   batch_size |   learning_rate |   weight_decay |   iter |   total time (s) |   test_loss |   test_acc1 |   test_nll |
|----------------------+------------+-----------------------+--------------+-----------------+----------------+--------+------------------+-------------+-------------+------------|
| TrainableNN_0a2818ed | RUNNING    | 141.51.131.91:1622207 |           64 |     0.024319    |     0.0562279  |        |                  |             |             |            |
| TrainableNN_179a7282 | RUNNING    | 141.51.131.91:1621510 |           32 |     0.000714326 |     0.00561798 |        |                  |             |             |            |
| TrainableNN_51f27e50 | RUNNING    | 141.51.131.91:1622423 |           64 |     0.00428947  |     0.0842709  |        |                  |             |             |            |
| TrainableNN_627145bc | RUNNING    | 141.51.131.91:1621983 |           64 |     0.00228726  |     0.0011372  |        |                  |             |             |            |
| TrainableNN_634e3c38 | RUNNING    | 141.51.131.91:1622514 |           32 |     0.00735346  |     0.00326644 |        |                  |             |             |            |
| TrainableNN_847f84c2 | RUNNING    | 141.51.131.91:1622335 |           32 |     0.0615403   |     0.0470682  |        |                  |             |             |            |
| TrainableNN_86f9352d | PAUSED     | 141.51.131.91:1617748 |           32 |     0.0833688   |     0.0442557  |      2 |          744.592 |    2.30598  |        9.99 |   2.30598  |
| TrainableNN_b163ff01 | PAUSED     | 141.51.131.91:1617532 |          128 |     0.097284    |     0.00458405 |      2 |          720.957 |    1.36935  |       55.12 |   1.36935  |
| TrainableNN_bbc0dfca | PAUSED     | 141.51.131.91:1617937 |           64 |     0.0962862   |     0.0234678  |      2 |          735.798 |    2.30484  |       10.59 |   2.30484  |
| TrainableNN_cfe3ae03 | PAUSED     | 141.51.131.91:1618211 |           64 |     0.0104651   |     0.0904184  |      2 |          687.52  |    2.53662  |       19.8  |   2.53662  |
| TrainableNN_d5356cd0 | PAUSED     | 141.51.131.91:1617633 |           64 |     0.00954714  |     0.0282319  |      2 |          747.867 |    1.23099  |       57.3  |   1.23099  |
| TrainableNN_f1292d20 | PAUSED     | 141.51.131.91:1617813 |           32 |     0.0121753   |     0.00257511 |      2 |          748.674 |    0.649187 |       78.24 |   0.649187 |
| TrainableNN_116393b0 | PENDING    | 141.51.131.91:1619749 |           64 |     0.01        |     0.0005     |      2 |          732.109 |    0.384798 |       87.61 |   0.384798 |
| TrainableNN_d801bc47 | PENDING    | 141.51.131.91:1618285 |           64 |     0.1         |     0.0005     |      2 |          718.29  |    0.725444 |       75.72 |   0.725444 |
| TrainableNN_e3c3f9e0 | PENDING    | 141.51.131.91:1620764 |          128 |     0.00537368  |     0.0548603  |      2 |          662.116 |    1.49474  |       52.15 |   1.49474  |
| TrainableNN_f95c1f87 | PENDING    | 141.51.131.91:1620946 |          128 |     0.00106557  |     0.0969766  |      2 |          706.713 |    0.844401 |       77.96 |   0.844401 |
| TrainableNN_27f57b91 | TERMINATED | 141.51.131.91:1616613 |          128 |     0.0797937   |     0.04606    |      1 |          382.141 |    2.15555  |       16.52 |   2.15555  |
| TrainableNN_487d0ae3 | TERMINATED | 141.51.131.91:1620631 |           32 |     0.0129602   |     0.0527052  |      2 |          738.71  |    2.62848  |        9.94 |   2.62848  |
| TrainableNN_5d4392f7 | TERMINATED | 141.51.131.91:1614546 |          128 |     0.0659099   |     0.0586861  |      1 |          377.781 |    2.30489  |        9.88 |   2.30489  |
| TrainableNN_65e368bf | TERMINATED | 141.51.131.91:1614457 |           64 |     0.037933    |     0.0349821  |      1 |          406.825 |    2.14053  |       18.27 |   2.14053  |
| TrainableNN_6ad37eea | TERMINATED | 141.51.131.91:1620339 |          128 |     0.0454381   |     0.0199769  |      2 |          737.796 |    2.56409  |       29.34 |   2.56409  |
| TrainableNN_a6b75104 | TERMINATED | 141.51.131.91:1620559 |           64 |     0.047798    |     0.0162622  |      2 |          797.612 |    3.96316  |       17.14 |   3.96316  |
+----------------------+------------+-----------------------+--------------+-----------------+----------------+--------+------------------+-------------+-------------+------------+
... 10 more trials not shown (2 RUNNING, 6 TERMINATED)


(TrainableNN pid=1621510)   Train:  Total time: 0:00:36 [repeated 3x across cluster]
(TrainableNN pid=1621510)   Train:  [   0/1250] eta: 0:00:29 lr: 0.0007107630034366765 loss: 0.5052 (0.5052) acc1: 78.1250 (78.1250) time: 0.0238 data: 0.0096 max mem: 317 [repeated 3x across cluster]
(TrainableNN pid=1622101)   Train:  [600/625] eta: 0:00:01 lr: 0.0017423756436208842 loss: 0.9568 (1.0319) acc1: 65.6250 (63.2514) time: 0.0533 data: 0.0165 max mem: 587 [repeated 4x across cluster]
(TrainableNN pid=1622101)   Train:  [ 400/1250] eta: 0:00:25 lr: 0.061536538844397505 loss: 2.2165 (2.2453) acc1: 12.5000 (14.3547) time: 0.0263 data: 0.0088 max mem: 317
(TrainableNN pid=1622335)   Train:  [ 400/1250] eta: 0:00:25 lr: 0.061536538844397505 loss: 2.2165 (2.2453) acc1: 12.5000 (14.3547) time: 0.0263 data: 0.0088 max mem: 317 [repeated 5x across cluster]
(TrainableNN pid=1622207)   Train:  [600/625] eta: 0:00:01 lr: 0.02431750893607888 loss: 1.7992 (1.8131) acc1: 31.2500 (31.4398) time: 0.0525 data: 0.0194 max mem: 587
(TrainableNN pid=1622207)   Train:  Total time: 0:00:33
Error executing job with overrides: ['dataset=CIFAR10', 'ood_datasets=[SVHN]', 'model=resnet18', 'output_dir=/mnt/stud/work/phahn/uncertainty/output/HP-Search_14519/', 'random_seed=1', 'cpus_per_trial=4', 'gpus_per_trial=1']
Traceback (most recent call last):
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/tune_controller.py", line 782, in _on_result
    on_result(trial, *args, **kwargs)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 694, in _on_training_result
    self._process_trial_results(trial, result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 707, in _process_trial_results
    decision = self._process_trial_result(trial, result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 750, in _process_trial_result
    decision = self._scheduler_alg.on_trial_result(
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/schedulers/hb_bohb.py", line 112, in on_trial_result
    trial_runner._search_alg.searcher.on_pause(trial.trial_id)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/search/concurrency_limiter.py", line 174, in on_pause
    self.searcher.on_pause(trial_id)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/search/bohb/bohb_search.py", line 275, in on_pause
    self.running.remove(trial_id)
KeyError: 'f23e8d5f'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/tuner.py", line 367, in fit
    return self._local_tuner.fit()
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/impl/tuner_internal.py", line 503, in fit
    analysis = self._fit_internal(trainable, param_space)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/impl/tuner_internal.py", line 621, in _fit_internal
    analysis = run(
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/tune.py", line 904, in run
    runner.step()
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/tune_controller.py", line 247, in step
    if not self._actor_manager.next(timeout=0.1):
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/air/execution/_internal/actor_manager.py", line 224, in next
    self._actor_task_events.resolve_future(future)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/air/execution/_internal/event_manager.py", line 118, in resolve_future
    on_result(result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/air/execution/_internal/actor_manager.py", line 748, in on_result
    self._actor_task_resolved(
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/air/execution/_internal/actor_manager.py", line 300, in _actor_task_resolved
    tracked_actor_task._on_result(tracked_actor, result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/tune_controller.py", line 791, in _on_result
    raise TuneError(traceback.format_exc())
ray.tune.error.TuneError: Traceback (most recent call last):
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/tune_controller.py", line 782, in _on_result
    on_result(trial, *args, **kwargs)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 694, in _on_training_result
    self._process_trial_results(trial, result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 707, in _process_trial_results
    decision = self._process_trial_result(trial, result)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 750, in _process_trial_result
    decision = self._scheduler_alg.on_trial_result(
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/schedulers/hb_bohb.py", line 112, in on_trial_result
    trial_runner._search_alg.searcher.on_pause(trial.trial_id)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/search/concurrency_limiter.py", line 174, in on_pause
    self.searcher.on_pause(trial_id)
  File "/mnt/stud/home/phahn/.conda/envs/uncertainty_evaluation/lib/python3.9/site-packages/ray/tune/search/bohb/bohb_search.py", line 275, in on_pause
    self.running.remove(trial_id)
KeyError: 'f23e8d5f'

Is there any chance the on_pause method is called multiple times? Sometimes during logging I get a note appended to my train stats of this form:

(TrainableNN pid=1622335)   Train:  [ 400/1250] eta: 0:00:25 lr: 0.061536538844397505 loss: 2.2165 (2.2453) acc1: 12.5000 (14.3547) time: 0.0263 data: 0.0088 max mem: 317 [repeated 5x across cluster]

I am not quite sure what this "repeated 5x across cluster" means. My only guess is that trying to remove the same trial id twice could lead to the error mentioned above…
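
To illustrate what I mean (plain Python set semantics, not Ray internals): the traceback ends in self.running.remove(trial_id), and removing the same id from a set twice raises exactly this kind of KeyError:

running = {"f23e8d5f"}
running.remove("f23e8d5f")        # first on_pause: succeeds
try:
    running.remove("f23e8d5f")    # a hypothetical second on_pause for the same trial
except KeyError as err:
    print("KeyError:", err)       # KeyError: 'f23e8d5f'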

Additionally, to make sure you have all the necessary information and I did not make a mistake somewhere else, here is the SLURM script I used to run this job:

#!/usr/bin/zsh
#SBATCH --mem=32gb
#SBATCH --ntasks=1
#SBATCH --cpus-per-task=40
#SBATCH --gres=gpu:4
#SBATCH --partition=main
#SBATCH --job-name=HP-Search
#SBATCH --output=/mnt/stud/work/phahn/uncertainty/logs/%x_%j.log
source /mnt/stud/home/phahn/.zshrc

rm /mnt/stud/work/phahn/uncertainty/uncertainty-evaluation/.git/index.lock

git checkout 52-hparam-optimization-for-uncertainty

conda activate uncertainty_evaluation

cd /mnt/stud/work/phahn/uncertainty/uncertainty-evaluation/experiments/uncertainty

OUTPUT_DIR=/mnt/stud/work/phahn/uncertainty/output/${SLURM_JOB_NAME}_${SLURM_JOB_ID}/
echo "Saving results to $OUTPUT_DIR"

srun python -u hparam_search.py \
    dataset=CIFAR10 \
    ood_datasets=\[SVHN\] \
    model=resnet18 \
    output_dir=$OUTPUT_DIR \
    random_seed=1 \
    cpus_per_trial=4 \
    gpus_per_trial=1

where hparam_search.py contains the code demonstrated in my original post.

Thanks for the info, I’ll try reproducing this with a simplified training script – basically just isolating the scheduling behavior of BOHB.

I found the time to create one myself, if it helps:

import os
import time

import torch

import ray
from ray import air, tune
from ray.tune import Trainable
from ray.tune.schedulers.hb_bohb import HyperBandForBOHB
from ray.tune.search.bohb import TuneBOHB
from ray.tune.execution.placement_groups import PlacementGroupFactory


class TrainableNN(Trainable):
    def setup(self, config):
        self.timestep = 0

    def step(self):
        time.sleep(2)
        return {'val_acc1':torch.randint(1, 100, (1,)).item()}

    def save_checkpoint(self, checkpoint_dir):
        checkpoint_path = os.path.join(checkpoint_dir, "model.pth")
        torch.save({}, checkpoint_path)
        return checkpoint_path

    def load_checkpoint(self, checkpoint_path):
        loaded_stuff = torch.load(checkpoint_path)
        print("do nothing here")

    @classmethod
    def default_resource_request(cls, config):
        return PlacementGroupFactory([{"CPU": 1, "GPU":0.25}])


def main():
    os.environ["TUNE_NEW_EXECUTION"] = "1"
    n_iterations = 4

    # Init Ray; if we are running under SLURM, set the CPU and GPU counts explicitly
    address = None
    num_cpus = int(os.environ.get('SLURM_CPUS_PER_TASK', 16))
    num_gpus = torch.cuda.device_count()
    ray.init(address=address, num_cpus=num_cpus, num_gpus=num_gpus)

    search_space = {
        'placeholder':tune.uniform(0, 1)
        }
    bohb_search = TuneBOHB(
        metric="val_acc1", 
        mode="max"
    )
    bohb_search = tune.search.ConcurrencyLimiter(bohb_search, max_concurrent=8)
    bohb_hyperband = HyperBandForBOHB(
        time_attr="training_iteration",
        max_t=n_iterations,
        reduction_factor=2,
        stop_last_trials=False,
    )

    tuner = tune.Tuner(
        tune.with_parameters(TrainableNN),
        run_config=air.RunConfig(
            stop={
                "training_iteration": n_iterations,
            },
        ),
        tune_config=tune.TuneConfig(
            search_alg=bohb_search,
            scheduler=bohb_hyperband,
            num_samples=100,
            metric="val_acc1", 
            mode="max"
        ),
        param_space=search_space
        )

    results = tuner.fit()    

    print('Best HP: {}'.format(results.get_best_result().metrics))



if __name__ == '__main__':
    main()

This resulted in:

2023-05-12 15:36:06,149 ERROR trial_runner.py:712 -- Trial TrainableNN_86444856: Error stopping trial.
Traceback (most recent call last):
  File "/home/phahn/miniconda3/envs/dal-toolbox/lib/python3.9/site-packages/ray/tune/execution/trial_runner.py", line 700, in stop_trial
    self._scheduler_alg.on_trial_complete(
  File "/home/phahn/miniconda3/envs/dal-toolbox/lib/python3.9/site-packages/ray/tune/schedulers/hyperband.py", line 299, in on_trial_complete
    self.on_trial_remove(trial_runner, trial)
  File "/home/phahn/miniconda3/envs/dal-toolbox/lib/python3.9/site-packages/ray/tune/schedulers/hyperband.py", line 293, in on_trial_remove
    self._process_bracket(trial_runner, bracket)
  File "/home/phahn/miniconda3/envs/dal-toolbox/lib/python3.9/site-packages/ray/tune/schedulers/hyperband.py", line 270, in _process_bracket
    raise TuneError(
ray.tune.error.TuneError: Trial with unexpected good status encountered: PENDING

@justinvyu Here you go, in case you haven't had the time yet :slight_smile:


Thanks, this is very helpful!

By the way, I have a PR in flight to mitigate this issue: [Tune] Fix hyperband scheduler raising an error for good `PENDING` trials by justinvyu · Pull Request #35338 · ray-project/ray · GitHub

Were you able to reproduce the 2nd error consistently?

    self.running.remove(trial_id)
KeyError: 'f23e8d5f'

I did not see that show up in my testing.


Thanks very much for your effort!

Sadly, I have not been able to reproduce the second error in a simple manner yet.