I am using the code snippet below in my model:
Per my understanding, the model.pth file should be saved under "/content/Run2".
I am able to view the **checkpoint file** under /content/Run2/DEFAULT_2021-07-04_19-05-07/DEFAULT_bf9e6_00000_0_batch_size=4,lr=5.6e-05_2021-07-04_19-05-08/checkpoint_000004/
But, it seems that .pth file for model’s state_dict is not saved here.
I have tried a few troubleshooting steps, but it seems I am missing something.
Please suggest how I can save and later load the model’s state_dict.
Thanks
def train_cifar(config, checkpoint_dir=None, data_dir=None):
    """Ray Tune trainable: fine-tune ``roberta-base`` on every fold.

    For each fold a fresh ``RobertaForSequenceClassification`` (one output
    label) is created, trained for ``EPOCHS`` epochs and validated after each
    epoch. After every epoch the model's ``state_dict`` is written with
    ``torch.save`` to a file literally named ``checkpoint`` (no ``.pth``
    extension) inside the directory handed out by ``tune.checkpoint_dir``,
    and the validation loss is reported to Tune as ``loss``.

    That ``checkpoint`` file *is* the saved ``state_dict``. To reuse it::

        state = torch.load(os.path.join(<.../checkpoint_00000N>, "checkpoint"))
        model.load_state_dict(state)

    Args:
        config: Hyperparameters of the current trial. ``config["lr"]`` and
            ``config["batch_size"]`` are used when present; otherwise the
            module-level ``learning_rate`` / ``TRAIN_BATCH`` are used.
        checkpoint_dir: Directory of a checkpoint to restore from, or
            ``None`` to start from the pretrained weights.
        data_dir: Unused; kept so existing callers keep working.
    """
    # NOTE(review): the key names "lr" and "batch_size" are taken from the
    # trial directory name produced by Tune -- confirm against the search
    # space. Falling back to the globals keeps the previous behaviour when
    # the keys are absent.
    lr = config.get("lr", learning_rate)
    train_batch = config.get("batch_size", TRAIN_BATCH)

    train_scores, train_losses, val_losses, val_scores = [], [], [], []

    for fold in range(folds):
        print('Creating Datasets..')
        train_, val_, test_ = create_dataset(
            train, test, fold, train_batch, VAL_BATCH, TEST_BATCH
        )

        # Every fold starts from a fresh pretrained model.
        model = RobertaForSequenceClassification.from_pretrained(
            'roberta-base', num_labels=1
        )
        model.to(device)
        gc.collect()
        FileLinks('./')
        torch.cuda.empty_cache()

        # Only restore when Tune actually handed us a checkpoint. The
        # per-epoch save below uses its own variable so that this test is
        # not accidentally made true for the following folds.
        if checkpoint_dir:
            restored_state = torch.load(
                os.path.join(checkpoint_dir, "checkpoint"),
                map_location=device,  # allow restoring on a different device
            )
            model.load_state_dict(restored_state)

        print(f'LR : {lr} , EPOCHS : {EPOCHS} FOLDS : {folds} BATCH SIZE : {train_batch}')

        for epoch in tqdm(range(EPOCHS)):
            print(f'fold : {fold} Epoch : {epoch}')

            print('training...')
            train_loss, train_rmse = train_model(
                train_, epoch, model=model, learning_rate=lr
            )
            train_losses.append(train_loss)
            train_scores.append(train_rmse)
            print(f'train error : {train_rmse}')

            print('validating')
            _, val_loss, val_score = validation(val_, model)
            val_losses.append(val_loss)
            val_scores.append(val_score)
            print(f'val error:{val_score}')

            torch.cuda.empty_cache()

            # NOTE(review): the step passed to Tune restarts at 0 for every
            # fold -- presumably later folds reuse the same checkpoint_xxxxxx
            # directories; confirm this is intended.
            with tune.checkpoint_dir(epoch) as epoch_ckpt_dir:
                torch.save(
                    model.state_dict(),
                    os.path.join(epoch_ckpt_dir, 'checkpoint'),
                )

            tune.report(loss=val_loss)
# Launch the hyperparameter search; every trial runs `train_cifar`.
result = tune.run(
    partial(train_cifar),
    # Search space and number of sampled configurations.
    config=config,
    num_samples=2,
    # Per-trial hardware budget.
    resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
    # Trial scheduling and console reporting.
    scheduler=scheduler,
    progress_reporter=reporter,
    # Root directory for trial results and checkpoints.
    local_dir="/content/Run2",
)