Hey guys! Does anyone know how to enable fault tolerance in Ray Train? The documentation suggests that I should implement checkpoint loading and saving. Is there anything else I need to set up, apart from that, if I want to use Ray Train's fault tolerance feature?
This is the relevant part of my train_func:
# Resume from the latest checkpoint if one exists, otherwise start fresh.
checkpoint = sgd.load_checkpoint()
if checkpoint is not None:
    model = checkpoint["model"]
    start_epoch = checkpoint.get("epoch", -1) + 1
else:
    model = create_model('resnet-18', 10)
    start_epoch = 0
for epoch in range(start_epoch, epochs):
    forward, backward, step, timedur = train_epoch(train_dataloader, model, loss_fn, optimizer, device, epoch)
    # Save after every epoch so a restarted worker can continue from here.
    sgd.save_checkpoint(epoch=epoch, model=model)
    sgd.report(forward=forward, backward=backward, step=step, time=timedur)
Is this enough, or do I need to set other parameters or add more code?
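For reference, this is how I understand the resume logic should behave. It is only a minimal sketch that runs without Ray: `InMemoryCheckpointStore`, `run_with_retries`, and the `fail_at_epoch` argument are hypothetical stand-ins I made up to mimic a worker failure and a restart. They are not Ray Train APIs, and the dict "model" is just a placeholder for the real one.

```python
class InMemoryCheckpointStore:
    """Hypothetical stand-in for the checkpoint storage; keeps only the latest state."""

    def __init__(self):
        self._latest = None

    def save(self, **state):
        self._latest = dict(state)

    def load(self):
        # Returns None until the first checkpoint has been saved.
        return self._latest


def train_func(store, epochs, fail_at_epoch=None, trained=None):
    """Same load/resume/save pattern as above; can raise once to mimic a failure."""
    trained = [] if trained is None else trained
    checkpoint = store.load()
    if checkpoint is not None:
        model = checkpoint["model"]
        start_epoch = checkpoint["epoch"] + 1
    else:
        model = {"steps": 0}  # placeholder for create_model(...)
        start_epoch = 0
    for epoch in range(start_epoch, epochs):
        if epoch == fail_at_epoch:
            raise RuntimeError(f"simulated worker failure in epoch {epoch}")
        model = {"steps": model["steps"] + 1}  # placeholder for train_epoch(...)
        trained.append(epoch)
        store.save(epoch=epoch, model=model)
    return model


def run_with_retries(epochs, fail_at_epoch=None, max_failures=1):
    """Re-run train_func after a failure, reusing the same checkpoint store."""
    store = InMemoryCheckpointStore()
    trained = []
    failures = 0
    while True:
        try:
            # Only the first attempt is allowed to fail in this sketch.
            inject = fail_at_epoch if failures == 0 else None
            model = train_func(store, epochs, inject, trained)
            return model, trained
        except RuntimeError:
            failures += 1
            if failures > max_failures:
                raise


model, trained = run_with_retries(epochs=5, fail_at_epoch=2)
print(trained, model)
```

If the resume logic is right, the restarted run should pick up at the epoch that failed, so every epoch is trained exactly once and none of the completed epochs are repeated.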