How to enable fault tolerance in Ray Train

Hey guys! Do you know how to enable fault tolerance in Ray Train? The documentation suggests that I should implement checkpoint saving and loading. Is there anything else I need to set up if I want to use the fault tolerance feature of Ray Train?

This is part of my train_func:

checkpoint = sgd.load_checkpoint()
if checkpoint is not None:
    model = checkpoint.get("model", 0)
    start_epoch = checkpoint.get("epoch", -1) + 1
else:
    model = create_model('resnet-18', 10)
    start_epoch = 0

    for epoch in range(start_epoch, epochs):
        forward, backward, step, timedur = train_epoch(train_dataloader, model, loss_fn, optimizer, device, epoch)
        sgd.save_checkpoint(epoch=epoch, model=model)
        sgd.report(forward=forward, backward=backward, step=step, time=timedur)

Is this enough, or do I need to set other parameters or add other code?

Hey @daxixi! The way you are using checkpointing looks correct to me! The only thing is that you probably want to move the for loop outside the if/else block. As written, the loop sits inside the else branch, so no training will actually happen when you are recovering from a checkpoint.
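Here is a minimal sketch of that change. To keep it runnable without Ray installed, it uses a hypothetical stand-in class, FakeSgd, that only mimics the three calls from the question (load_checkpoint, save_checkpoint, report). It is not the real library, and the dict model and the counter update are placeholders for create_model and train_epoch.

```python
# Hypothetical stand-in for the `sgd` module used in the question.
# It is NOT the Ray API; it only mimics the three calls shown above.
class FakeSgd:
    def __init__(self, checkpoint=None):
        self._checkpoint = checkpoint  # what load_checkpoint() returns
        self.reports = []              # everything passed to report()

    def load_checkpoint(self):
        return self._checkpoint

    def save_checkpoint(self, **kwargs):
        self._checkpoint = dict(kwargs)

    def report(self, **kwargs):
        self.reports.append(kwargs)


def train_func(sgd, epochs):
    checkpoint = sgd.load_checkpoint()
    if checkpoint is not None:
        # Recovering: resume from the epoch after the last saved one.
        model = checkpoint["model"]
        start_epoch = checkpoint["epoch"] + 1
    else:
        # Fresh start: a dict stands in for create_model('resnet-18', 10).
        model = {"steps": 0}
        start_epoch = 0

    # The loop is outside the if/else, so it runs in both cases.
    for epoch in range(start_epoch, epochs):
        model["steps"] += 1  # placeholder for train_epoch(...)
        sgd.save_checkpoint(epoch=epoch, model=model)
        sgd.report(epoch=epoch)
    return model


# Fresh run: trains epochs 0, 1 and 2.
fresh = FakeSgd()
train_func(fresh, epochs=3)

# Recovery: a checkpoint from epoch 1 exists, so only epoch 2 is trained.
resumed = FakeSgd(checkpoint={"epoch": 1, "model": {"steps": 2}})
train_func(resumed, epochs=3)
```

In your real train_func the structure stays the same: only the indentation of the for loop changes, and the checkpoint handling stays as you wrote it.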

Thanks very much for your reply and for reminding me of that!