What is the best way to load the saved checkpoint file? I used `train.save_checkpoint`
to save the best model. The script that saves the best model is the following:
# training and validation loops
#
# Runs `epochs` rounds of train + validation, reports the validation metrics
# to Ray Train after every epoch, keeps a copy of the weights from the epoch
# with the best validation IoU, and finally stores those weights with
# `train.save_checkpoint`.
loss_train, loss_valid, best_loss, best_epoch, best_iou, best_f1 = [], [], 9999, 9999, 0, 0
# Stays None until some epoch improves the validation IoU; initialised here so
# the code after the loop never hits an unbound name.
best_model_wts = None
since = time.time()
for epoch in range(1, epochs + 1):
    # training step
    print_with_rank(f"\n{'--' * 15} EPOCH: {epoch} | {epochs} {'--' * 15}\n")
    epoch_train_loss, epoch_train_iou, epoch_train_f1 = train_one_epoch(net, trainloader, criterion, optimizer, lr_scheduler, scaler)
    print_with_rank("\nPhase: train | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_train_loss, epoch_train_iou, epoch_train_f1))
    # validation step
    epoch_val_loss, epoch_val_iou, epoch_val_f1 = valid_one_epoch(net, valloader, criterion)
    print_with_rank("\nPhase: val | Loss: {:.4f} | IoU: {:.4f} | F1: {:.4f}".format(epoch_val_loss, epoch_val_iou, epoch_val_f1))
    # logs
    loss_train.append(epoch_train_loss)
    loss_valid.append(epoch_val_loss)
    # Only count an improvement when the IoU (rounded to 4 decimals) rises by
    # at least 1e-3, so tiny fluctuations do not replace the best snapshot.
    diff = abs(np.round(epoch_val_iou, 4) - np.round(best_iou , 4))
    if (np.round(epoch_val_iou, 4) > np.round(best_iou , 4)) and diff >= 1e-3:
        print_with_rank("IoU improved from {:.4f} to {:.4f}.".format(best_iou, epoch_val_iou))
        best_loss = epoch_val_loss
        # deepcopy: state_dict() hands back references to the live tensors,
        # which later epochs would otherwise keep overwriting.
        best_model_wts = copy.deepcopy(net.module.state_dict())
        best_epoch = epoch
        best_iou = epoch_val_iou
        best_f1 = epoch_val_f1
    # Report every epoch (not only improving ones) so the tuner sees the
    # full metric history.
    train.report(loss=epoch_val_loss, f1=epoch_val_f1, iou=epoch_val_iou)
time_elapsed = time.time() - since
print_with_rank("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
if best_model_wts is None:
    # No epoch beat the initial best_iou by >= 1e-3 (or the loop never ran).
    # Fall back to the final weights instead of failing with a NameError.
    print_with_rank("IoU never improved; saving the final model weights instead.")
    best_model_wts = copy.deepcopy(net.module.state_dict())
    best_epoch = epochs
# The snapshot is taken from `net.module`, so its keys presumably carry no
# "module." prefix already; this call is kept as a safety net.
consume_prefix_in_state_dict_if_present(best_model_wts, "module.")
# save best model
# NOTE: the keyword arguments below become the keys of the stored checkpoint,
# so a reader gets the weights back under the "model_weights" key.
print_with_rank("Saving best model")
train.save_checkpoint(epoch=best_epoch, model_weights=best_model_wts)
print_with_rank("Best model loss: {:.4f} and corresponding iou: {:.4f}, f1 : {:.4f}".format(best_loss, best_iou, best_f1))
Directory tree of the saved models:
outputs/raytune_result/tune_function_2022-03-26_07-33-13/
├── basic-variant-state-2022-03-26_07-33-13.json
├── experiment_state-2022-03-26_07-33-13.json
├── tune_function_fe5ef_00000_0_loss_fn=tversky,lr=0.064116,lr_warmup_decay=0.018869,lr_warmup_epochs=5,lr_warmup_method=linear,weight_2022-03-26_07-33-13
│ ├── checkpoint_000000
│ │ └── checkpoint
│ ├── events.out.tfevents.1648279993.dgxstation-a100
│ ├── params.json
│ ├── params.pkl
│ ├── progress.csv
│ └── result.json
├── tune_function_fe5ef_00001_1_loss_fn=focal,lr=0.028107,lr_warmup_decay=0.022144,lr_warmup_epochs=5,lr_warmup_method=constant,weight_2022-03-26_07-33-15
│ ├── checkpoint_000000
│ │ └── checkpoint
│ ├── events.out.tfevents.1648279995.dgxstation-a100
│ ├── params.json
│ ├── params.pkl
│ ├── progress.csv
│ └── result.json
When I use the plain `torch.load` function to load the checkpoint file, it gives the following error:
>>> import torch
>>> model_state = torch.load(os.path.join(path, "checkpoint"))
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/mycomp/anaconda3/envs/bop/lib/python3.8/site-packages/torch/serialization.py", line 608, in load
return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File "/home/mycomp/anaconda3/envs/bop/lib/python3.8/site-packages/torch/serialization.py", line 779, in _legacy_load
raise RuntimeError("Invalid magic number; corrupt file?")
RuntimeError: Invalid magic number; corrupt file?
I then tried to use `load_checkpoint_from_path`. With it I am able to load the checkpoint file, but when I try to load the weights into the model architecture, it reports a key mismatch.
import torch
from pathlib import Path
from typing import List, Optional, Dict, Union, Callable
from ray import cloudpickle
def load_checkpoint_from_path(checkpoint_to_load: Union[str, Path]) -> Dict:
    """Load a checkpoint from a path with ``cloudpickle``.

    Args:
        checkpoint_to_load: Path to the checkpoint file. ``~`` is expanded.
            A directory may also be given (e.g. a ``checkpoint_000000``
            folder), in which case the file named ``checkpoint`` inside it
            is loaded.

    Returns:
        The object stored in the file. For checkpoints written with
        ``train.save_checkpoint(**kwargs)`` this is expected to be a dict
        holding those keyword arguments, not a bare ``state_dict``
        (confirm against the Ray version in use).

    Raises:
        ValueError: If the resolved checkpoint path does not exist.
    """
    checkpoint_path = Path(checkpoint_to_load).expanduser()
    # Accept the checkpoint directory as well as the file itself; opening a
    # directory would otherwise fail with an unrelated OS error.
    if checkpoint_path.is_dir():
        checkpoint_path = checkpoint_path / "checkpoint"
    if not checkpoint_path.exists():
        raise ValueError(f"Checkpoint path {checkpoint_path} does not exist.")
    # SECURITY: unpickling can execute arbitrary code; only load checkpoint
    # files that come from a trusted source.
    with checkpoint_path.open("rb") as f:
        return cloudpickle.load(f)
# Restore the saved weights into a freshly built model.
# The checkpoint was written with
# `train.save_checkpoint(epoch=..., model_weights=...)`, so the loaded object
# is a dict wrapping the weights under "model_weights" (next to "epoch" and
# possibly bookkeeping keys added by Ray). Passing the whole dict to
# `load_state_dict` is what causes the key mismatch.
net = Mymodel()
checkpoint = load_checkpoint_from_path(full_path_to_checkpoint)
model_wts = checkpoint["model_weights"]
net.load_state_dict(model_wts)