Getting Tune to read Train checkpoint in ray.train.report

1. Severity of the issue: (select one)
None: I’m just curious or want clarification.
Low: Annoying but doesn’t hinder my work.
Medium: Significantly affects my productivity but can find a workaround.
High: Completely blocks me.

2. Environment:

  • Ray version: 2.44.1
  • Python version: 3.10.12
  • OS: Linux
  • Cloud/Infrastructure:
  • Other libs/tools (if relevant):

3. Repro steps / sample code:

Train function per worker

def train_fn_per_worker(train_loop_config):
    ...  # Create my model

    model = ray.train.torch.prepare_model(model)
    criterion, optimizer = createLossFunction(model, weights=loss_weights, lr=lr)

    ... # Have own function to create loaders and return sampler

    train_loader = ray.train.torch.prepare_data_loader(train_loader)
    val_loader = ray.train.torch.prepare_data_loader(val_loader)

    train(model, optimizer, criterion, train_loader, train_sampler, model.device, num_epochs, save_epoch, val_loader=val_loader, label_map=label_map)

Train Driver

def tune_train_driver_fn(config, train_config):
    trainer = ray.train.torch.TorchTrainer(
        train_fn_per_worker,
        train_loop_config=train_config,
        scaling_config=ray.train.ScalingConfig(
            num_workers=train_config["num_workers"],
            use_gpu=True,
        ),
        run_config=ray.train.RunConfig(
            name=f"{model_name}_train",
            storage_path="/workspace/ray_train_results",
            callbacks=[TuneReportCallback()],
            checkpoint_config=ray.train.CheckpointConfig(
                num_to_keep=2,
                checkpoint_score_attribute="iou",
                checkpoint_score_order="max"
            ),
        ),
        torch_config=ray.train.torch.TorchConfig(
            backend="gloo",
        )
    )
    trainer.fit()

Tuner

ray.init()
param_space = {
    "epochs": tune.choice([50, 100, 150, 200]),
    "encoder_name": tune.choice(["resnet50", "resnet101", "resnet152", "mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4",
                                 "efficientnet-b0", "efficientnet-b1", "efficientnet-b2", "efficientnet-b3", "efficientnet-b4"]),
    "loss_weights": tune.choice([None, torch.tensor([1.0, 1.0, 1.0, 1.75]), torch.tensor([1.0, 1.0, 1.0, 1.5]), torch.tensor([1.0, 1.0, 1.0, 1.25])]),
    "lr": tune.loguniform(1e-5, 1e-3)
}

train_config = {
     ... # configuration for just setting up training
}

scheduler = ASHAScheduler(
    metric="iou",
    mode="max",
    max_t=num_epochs,
    grace_period=1,
    reduction_factor=2
)
os.makedirs("/workspace/ray_tune_results", exist_ok=True)
tuner = tune.Tuner(
    tune.with_parameters(tune_train_driver_fn, train_config=train_config),
    param_space=param_space,
    tune_config=tune.TuneConfig(
        num_samples=20,
        max_concurrent_trials=2,
        scheduler=scheduler,
    ),
    run_config=ray.tune.RunConfig(
        name=f"{model_name}_tune",
        storage_path="/workspace/ray_tune_results",
    )
)

results = tuner.fit()
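
For context, this is how I expect to read the reported checkpoints back out of the Tune results afterwards (just a sketch; the metric/mode mirror the ASHA settings above):

best_result = results.get_best_result(metric="iou", mode="max")
print(best_result.metrics)  # last reported metrics for the best trial

# The checkpoint reported from the Train workers should be accessible here
if best_result.checkpoint is not None:
    with best_result.checkpoint.as_directory() as ckpt_dir:
        print(os.listdir(ckpt_dir))  # expect the .pt file saved by rank 0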

How I save checkpoint

def train(model, optimizer, criterion, train_loader, train_sampler, device, num_epochs, save_epoch, val_loader=None, label_map=None):
    for epoch in range(num_epochs):
        model.train()
        train_sampler.set_epoch(epoch)
        running_train_loss = 0.0
        for images, masks in train_loader:
            ... # Training


        model.eval()
        running_val_loss = 0.0
        all_true_masks = []
        all_pred_masks = []
        with torch.no_grad():
           for images, masks in val_loader:
                 ... # Validation

        all_true_masks = np.concatenate(all_true_masks)
        all_pred_masks = np.concatenate(all_pred_masks)
        f1 = f1_score(all_true_masks, all_pred_masks, average='weighted')
        iou = np.mean([np.sum((all_true_masks == i) & (all_pred_masks == i)) / np.sum((all_true_masks == i) | (all_pred_masks == i)) for i in range(len(label_map.colormap))])
        dice = np.mean([2 * np.sum((all_true_masks == i) & (all_pred_masks == i)) / (np.sum(all_true_masks == i) + np.sum(all_pred_masks == i)) for i in range(len(label_map.colormap))])
        metrics = {
            'train_loss': avg_train_loss,
            'val_loss': avg_val_loss,
            'f1_score': f1,
            'iou': iou,
            'dice': dice
        }
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            checkpoint = None
            # Only rank 0 writes checkpoint files; every worker still calls report
            if (epoch+1) % save_epoch == 0 and ray.train.get_context().get_world_rank() == 0:
                torch.save(
                    model.module.state_dict(),
                    os.path.join(temp_checkpoint_dir, f'{model_name}_checkpoint_{epoch+1}.pt')
                )
                checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
            # Report inside the tempdir context so the checkpoint is persisted before the directory is deleted
            ray.train.report(metrics=metrics, checkpoint=checkpoint)

4. What happened vs. what you expected:

  • Expected: The run to complete without error, with Tune picking up the checkpoints reported via ray.train.report.
  • Actual: The trial fails with the following error:
2025-04-03 18:30:44,323 ERROR tune_controller.py:1331 -- Trial task failed for trial tune_train_driver_fn_eb954_00000
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2782, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 929, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::ImplicitFunc.train() (pid=545996, ip=172.17.0.2, actor_id=4e566e69fd38014739f3082801000000, repr=tune_train_driver_fn)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/trainable.py", line 330, in train
    raise skipped from exception_cause(skipped)
  File "/usr/local/lib/python3.10/dist-packages/ray/air/_internal/util.py", line 107, in run
    self._ret = self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/function_trainable.py", line 45, in <lambda>
    training_func=lambda: self._trainable_func(self.config),
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/function_trainable.py", line 261, in _trainable_func
    output = fn()
  File "/usr/local/lib/python3.10/dist-packages/ray/tune/trainable/util.py", line 130, in inner
    return trainable(config, **fn_kwargs)
  File "/workspace/oi-cvml/CustomTrainer.py", line 556, in tune_train_driver_fn
    trainer.fit()
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/api/data_parallel_trainer.py", line 112, in fit
    result = self._initialize_and_run_controller(
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/api/data_parallel_trainer.py", line 192, in _initialize_and_run_controller
    controller.run()
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/controller/controller.py", line 451, in run
    self._run_control_loop_iteration()
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/controller/controller.py", line 441, in _run_control_loop_iteration
    result = self._step()
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/controller/controller.py", line 361, in _step
    worker_group_status = self._poll_workers()
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/controller/controller.py", line 252, in _poll_workers
    status = self._worker_group.poll_status(timeout=self._health_check_interval_s)
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/worker_group/worker_group.py", line 439, in poll_status
    callback.after_worker_group_poll_status(worker_group_poll_status)
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/checkpoint/report_handler.py", line 98, in after_worker_group_poll_status
    callback.after_report(
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py", line 269, in after_report
    self.register_checkpoint(
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py", line 130, in register_checkpoint
    self._write_state_to_storage()
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py", line 224, in _write_state_to_storage
    checkpoint_manager_snapshot = self._save_state()
  File "/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py", line 162, in _save_state
    return manager_snapshot.model_dump_json()
AttributeError: '_CheckpointManagerState' object has no attribute 'model_dump_json'

The error happens immediately after the call to ray.train.report finishes.

I even copied the skeleton code from the Hyperparameter Tuning with Ray Tune guide (Ray 2.44.1 docs) into a new file, and it still failed with the same error.

The only thing I added was setting os.environ['RAY_TRAIN_V2_ENABLED'] = '1', as shown below.
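
For completeness, the flag was set at the top of the copied script, roughly like this (a sketch; setting it before the Ray imports is just a precaution on my part):

import os

# Opt in to Ray Train V2 before anything from Ray is imported
os.environ["RAY_TRAIN_V2_ENABLED"] = "1"

import ray
from ray import train, tune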

To work around it, I modified
/usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py, line 162,
changing .model_dump_json() to .json(), which fixes the error for me.
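
For reference, the change is the single return statement shown in the traceback (the exact line number may differ between Ray patch releases). My guess is that model_dump_json() is the Pydantic v2 method and .json() is its Pydantic v1 equivalent, so this code path may assume Pydantic v2 while my environment has Pydantic v1 installed; that part is only my interpretation.

# /usr/local/lib/python3.10/dist-packages/ray/train/v2/_internal/execution/checkpoint/checkpoint_manager.py
# inside _save_state(), around line 162

# before
return manager_snapshot.model_dump_json()

# after (workaround)
return manager_snapshot.json()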