Best approach to load saved checkpoint

What is the best way to load a saved checkpoint file? I used train.save_checkpoint to save the best models. The script that saves the best model is as follows:

    # training and validation loops
    loss_train, loss_valid, best_loss, best_epoch, best_iou, best_f1 = [], [], 9999, 9999, 0, 0
    # Snapshot the starting weights so that best_model_wts is always bound, even
    # if the validation IoU never improves by at least 1e-3 (otherwise the save
    # step below would fail with a NameError). In that case best_loss and
    # best_epoch keep their 9999 sentinel values.
    best_model_wts = copy.deepcopy(net.module.state_dict())
    since = time.time()
    for epoch in range(1, epochs + 1):
        # training step
        print_with_rank(f"\n{'--' * 15} EPOCH: {epoch} | {epochs} {'--' * 15}\n")
        epoch_train_loss, epoch_train_iou, epoch_train_f1 = train_one_epoch(net, trainloader, criterion, optimizer, lr_scheduler, scaler)
        print_with_rank("\nPhase: train | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_train_loss, epoch_train_iou, epoch_train_f1))
        # validation step
        epoch_val_loss, epoch_val_iou, epoch_val_f1 = valid_one_epoch(net, valloader, criterion)
        print_with_rank("\nPhase: val | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_val_loss, epoch_val_iou, epoch_val_f1))
        # logs
        loss_train.append(epoch_train_loss)
        loss_valid.append(epoch_val_loss)

        # Keep a copy of the weights only when the validation IoU, compared at
        # 4 decimal places, rises by at least 1e-3 over the best seen so far.
        diff = abs(np.round(epoch_val_iou, 4) - np.round(best_iou, 4))
        if (np.round(epoch_val_iou, 4) > np.round(best_iou, 4)) and diff >= 1e-3:
            print_with_rank("IoU improved from {:.4f} to {:.4f}.".format(best_iou, epoch_val_iou))
            best_loss = epoch_val_loss
            # deepcopy so later training steps do not mutate the saved tensors
            best_model_wts = copy.deepcopy(net.module.state_dict())
            best_epoch = epoch
            best_iou = epoch_val_iou
            best_f1 = epoch_val_f1

        # report this epoch's validation metrics
        train.report(loss=epoch_val_loss, f1=epoch_val_f1, iou=epoch_val_iou)

    time_elapsed = time.time() - since
    print_with_rank("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))

    # The weights were taken from net.module, so a "module." prefix is not
    # expected here; this call drops it from the keys should one be present.
    consume_prefix_in_state_dict_if_present(best_model_wts, "module.")
    # save best model
    print_with_rank("Saving best model")
    # The checkpoint is written with two entries, "epoch" and "model_weights";
    # the state dict must be read back from the "model_weights" entry.
    train.save_checkpoint(epoch=best_epoch, model_weights=best_model_wts)
    print_with_rank("Best model loss: {:.4f} and corresponding iou: {:.4f}, f1 : {:.4f}".format(best_loss, best_iou, best_f1))
    

Saved models tree

outputs/raytune_result/tune_function_2022-03-26_07-33-13/
├── basic-variant-state-2022-03-26_07-33-13.json
├── experiment_state-2022-03-26_07-33-13.json
├── tune_function_fe5ef_00000_0_loss_fn=tversky,lr=0.064116,lr_warmup_decay=0.018869,lr_warmup_epochs=5,lr_warmup_method=linear,weight_2022-03-26_07-33-13
│   ├── checkpoint_000000
│   │   └── checkpoint
│   ├── events.out.tfevents.1648279993.dgxstation-a100
│   ├── params.json
│   ├── params.pkl
│   ├── progress.csv
│   └── result.json
├── tune_function_fe5ef_00001_1_loss_fn=focal,lr=0.028107,lr_warmup_decay=0.022144,lr_warmup_epochs=5,lr_warmup_method=constant,weight_2022-03-26_07-33-15
│   ├── checkpoint_000000
│   │   └── checkpoint
│   ├── events.out.tfevents.1648279995.dgxstation-a100
│   ├── params.json
│   ├── params.pkl
│   ├── progress.csv
│   └── result.json

When I use the normal torch.load function to load the checkpoint file, it gives the following error:

>>> import torch
>>> model_state = torch.load(os.path.join(path, "checkpoint"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/mycomp/anaconda3/envs/bop/lib/python3.8/site-packages/torch/serialization.py", line 608, in load
    return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
  File "/home/mycomp/anaconda3/envs/bop/lib/python3.8/site-packages/torch/serialization.py", line 779, in _legacy_load
    raise RuntimeError("Invalid magic number; corrupt file?")
RuntimeError: Invalid magic number; corrupt file?

I also tried to use load_checkpoint_from_path. With this I am able to load the checkpoint file, but when I try to load the weights into the architecture it reports a key mismatch.

import torch
from pathlib import Path
from typing import List, Optional, Dict, Union, Callable
from ray import cloudpickle

def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict:
    """Read a pickled checkpoint dictionary back from disk.

    Args:
        checkpoint_to_load: Location of the checkpoint file. A leading ``~``
            is expanded to the user's home directory.

    Returns:
        The object deserialized from the file with ``cloudpickle``.

    Raises:
        ValueError: If nothing exists at the (expanded) path.
    """
    resolved = Path(checkpoint_to_load).expanduser()
    if resolved.exists():
        # NOTE: unpickling can execute arbitrary code; only load checkpoint
        # files that come from a trusted source.
        with resolved.open("rb") as handle:
            return cloudpickle.load(handle)
    raise ValueError(f"Checkpoint path {resolved} does not exist.")


# Instantiate the architecture and read the checkpoint back from disk.
net = Mymodel()
model_wts = load_checkpoint_from_path(full_path_to_checkpoint)
# NOTE: model_wts is the whole checkpoint dict, holding the keyword arguments
# given to train.save_checkpoint (epoch=..., model_weights=...), not a bare
# state dict -- so this is the call that reports the key mismatch. The
# weights themselves are under model_wts['model_weights'].
net.load_state_dict(model_wts)

Hey @dudeperf3ct do you mind sharing the full reproducible script that you are running if possible? In particular, how you are using the Trainer or Ray Tune as well.

@amogkam here’s the gist : train_tuner_seg.py · GitHub

I was able to resolve the issue using the second approach. The weights are stored under the model_wts['model_weights'] key, and with that I was able to load the model successfully.