Training does not resume from the last checkpoint after an unexpected node death

I ran 'ray start --head' on a head node in Docker, and connected two worker nodes to it with 'ray start --address', also in Docker. When I ran the training code on the head node, it used 2 of the 4 CPUs. To simulate a node failure, I stopped the Docker container on the worker node that Ray had assigned to the training. The training failed as expected and automatically restarted, and the log showed "(TorchTrainer pid=131, ip=95.217.176.214) Restored on 95.217.176.214 from checkpoint: Checkpoint(filesystem=s3, path=random-distribution/transformer-hugging-face/TorchTrainer_6eea7_00000_0_2024-07-25_23-51-06/checkpoint_000004)". However, training started again from epoch 1 and the loss was high, so it did not actually resume from the last checkpoint. How can I solve this? Thanks.
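
For reference, this is the resume pattern as far as I understand it from the Ray Train fault-tolerance docs for Hugging Face Transformers (a minimal sketch with a hypothetical helper resume_aware_train(), not my actual code; I am assuming the checkpoint written by RayTrainReportCallback contains a Transformers checkpoint in a "checkpoint" subfolder that trainer.train() can resume from):

import os
from ray import train

def resume_aware_train(trainer):
    # checkpoint restored by Ray Train after a failure, or None on the first run
    checkpoint = train.get_checkpoint()
    if checkpoint:
        # materialize the checkpoint into a local directory
        with checkpoint.as_directory() as checkpoint_dir:
            # assumption: RayTrainReportCallback stores the HF checkpoint in a "checkpoint" subfolder
            hf_checkpoint = os.path.join(checkpoint_dir, "checkpoint")
            trainer.train(resume_from_checkpoint=hf_checkpoint)
    else:
        trainer.train()

I am not sure whether my training function below should be doing this instead.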

This is my code:

from datasets import load_dataset
from transformers import Trainer, TrainingArguments, AutoTokenizer, AutoModelForSequenceClassification

# Ray and Ray Train imports
import ray
from ray import train
from ray.train import ScalingConfig, RunConfig, CheckpointConfig, Checkpoint, FailureConfig
from ray.train.torch import TorchTrainer

from ray.train.huggingface.transformers import RayTrainReportCallback

import pyarrow

def train_func():
    # load dataset
    dataset = load_dataset("imdb")

    # load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)  # IMDB has two sentiment labels

    # init tokenize function
    def tokenize_function(examples):
        return tokenizer(examples["text"], padding="max_length", truncation=True)

    # Select a small subset of the data for training and evaluation
    small_train_dataset = dataset["train"].select(range(5)).map(tokenize_function, batched=True)
    small_eval_dataset = dataset["test"].select(range(5)).map(tokenize_function, batched=True)

    # get the checkpoint restored by Ray Train (None on the first run)
    checkpoint = train.get_checkpoint()
    print("checkpoint:", checkpoint)

    # init training args
    training_args = TrainingArguments(
        output_dir="distilbert",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        report_to="none",
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        # train for 15 epochs
        num_train_epochs=15,
        resume_from_checkpoint=checkpoint
    )

    # init trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=small_train_dataset,
        eval_dataset=small_eval_dataset,
    )

    # report metrics and checkpoints to Ray Train, then prepare the trainer for distributed execution
    trainer.add_callback(RayTrainReportCallback())
    trainer = ray.train.huggingface.transformers.prepare_trainer(trainer)

    # start training
    trainer.train()

# configure the S3 filesystem used for checkpoint storage
fs = pyarrow.fs.S3FileSystem(
    # The keys are removed here for security reasons,
    # but checkpoints are saved successfully when they are set.
    access_key="",
    secret_key="",
    region="ap-southeast-2",
    check_directory_existence_before_creation=True
)

ray_trainer = TorchTrainer(
    train_func,
    run_config=RunConfig(
        storage_filesystem=fs,
        storage_path="random-distribution",
        name="distilbert",
        checkpoint_config=CheckpointConfig(
            num_to_keep=3,
            checkpoint_score_attribute="eval_loss",  # metric used to rank checkpoints
            checkpoint_score_order="min",
        ),
        failure_config=FailureConfig(max_failures=-1)  # retry indefinitely on failures
    ),
    scaling_config=ScalingConfig(num_workers=1, use_gpu=False)
)

# start training
ray_trainer.fit()

# shutdown
ray.shutdown()