Hey guys! Does anyone know how to enable fault tolerance in Ray Train? The documentation suggests that I implement checkpoint saving and loading. Is there anything else I need to set up, apart from that, to use Ray Train's fault tolerance feature?
This is the relevant part of my code:
# Resume from the latest checkpoint if one exists
checkpoint = sgd.load_checkpoint()
if checkpoint is not None:
    model = checkpoint.get("model", 0)
    start_epoch = checkpoint.get("epoch", -1) + 1
else:
    model = create_model('resnet-18', 10)
    start_epoch = 0

for epoch in range(start_epoch, epochs):
    forward, backward, step, timedur = train_epoch(train_dataloader, model, loss_fn, optimizer, device, epoch)
    # Save state after every epoch so a restarted worker can resume from here
    sgd.save_checkpoint(epoch=epoch, model=model)
    sgd.report(forward=forward, backward=backward, step=step, time=timedur)
Is this enough, or do I need to set other parameters or add more code?
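To make clear what behavior I expect from fault tolerance, here is a minimal stand-alone sketch of the resume logic. It does not use any Ray APIs: a plain dict stands in for checkpoint storage, and `run_training`, `store`, and `fail_at` are hypothetical names I made up for illustration. The counter `state` is only a placeholder for the model.

```python
# Stand-alone sketch of resume-from-checkpoint; no Ray APIs are used.
# `run_training`, `store`, and `fail_at` are hypothetical, for illustration only.

def run_training(store, epochs, fail_at=None):
    """Train from the last saved epoch; `store` stands in for checkpoint storage."""
    checkpoint = store.get("latest")
    if checkpoint is not None:
        state = checkpoint["model"]
        start_epoch = checkpoint["epoch"] + 1
    else:
        state = 0
        start_epoch = 0

    executed = []
    for epoch in range(start_epoch, epochs):
        if fail_at is not None and epoch == fail_at:
            raise RuntimeError(f"simulated worker failure at epoch {epoch}")
        state += 1  # placeholder for the real train_epoch(...) call
        store["latest"] = {"epoch": epoch, "model": state}  # save after each epoch
        executed.append(epoch)
    return executed, state


store = {}
try:
    run_training(store, epochs=5, fail_at=3)  # first attempt dies at epoch 3
except RuntimeError:
    pass  # a fault-tolerant runner would restart the worker at this point

# Second attempt picks up from the last saved checkpoint
resumed_epochs, final_state = run_training(store, epochs=5)
```

In this sketch the second call only runs the epochs that were not checkpointed before the failure. That is what I hope Ray Train does with my loop when a worker dies, but I am not sure whether the restart itself needs extra configuration.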