Python kernel is killed while fine-tuning models

I'm using a Ray cluster on Kubernetes and connecting from an external Jupyter Notebook.
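
We connect to the cluster over Ray Client, roughly like this (the address below is a placeholder for our KubeRay head service, not the real one):

import ray

# Ray Client connection from the external notebook; the address is a placeholder.
ray.init("ray://raycluster-head-svc.ray.svc.cluster.local:10001")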

While running my notebook to fine-tune a Hugging Face model, the kernel is killed at this step:

from ray.train.huggingface import HuggingFaceTrainer
from ray.air.config import ScalingConfig
from ray.data.preprocessors import Chain

trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    trainer_init_config={
        "batch_size": 16,
        "epochs": 1,
    },
    scaling_config=ScalingConfig(
        num_workers=num_workers,
        use_gpu=use_gpu,
        resources_per_worker={"GPU": 1, "CPU": cpus_per_worker},
    ),
    datasets={"train": ray_datasets["train"], "evaluation": ray_datasets["validation"]},
    preprocessor=Chain(splitter, tokenizer),
)

results = trainer.fit()

trainer.fit() trains the model successfully, but at the end the kernel is killed after emitting this warning:

UserWarning: Ray Client is attempting to retrieve a 5.53 GiB object over the network, which may be slow. Consider serializing the object to a file and using S3 or rsync instead

I'm unable to find any docs that explain how to apply the serialization workaround suggested by the warning.
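
Is configuring cloud checkpointing via RunConfig / SyncConfig what the warning is hinting at? Below is a minimal sketch of what I pieced together from the API reference; the experiment name and S3 path are placeholders, and I have not verified that this actually stops the result checkpoint from being pulled over Ray Client:

from ray.air.config import RunConfig, ScalingConfig
from ray.tune import SyncConfig

# Placeholder config: persist checkpoints to S3 instead of returning them to the client.
run_config = RunConfig(
    name="gptj-finetune",  # placeholder experiment name
    sync_config=SyncConfig(upload_dir="s3://my-bucket/ray-results"),  # placeholder bucket
)

trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    trainer_init_config={"batch_size": 16, "epochs": 1},
    scaling_config=ScalingConfig(
        num_workers=num_workers,
        use_gpu=use_gpu,
        resources_per_worker={"GPU": 1, "CPU": cpus_per_worker},
    ),
    datasets={"train": ray_datasets["train"], "evaluation": ray_datasets["validation"]},
    preprocessor=Chain(splitter, tokenizer),
    run_config=run_config,
)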

Any help would be much appreciated, thanks!

Versions:
Kubernetes Version: v1.25.6
Ray Version: 2.3.1
Python Version: 3.8

@nikhil.das What does your trainer_init_per_worker code look like? And where are the datasets?
I assume you are using HF datasets and then converting them into Ray Data, right? Where is all that happening?

Also, where is the HuggingFace model created?

cc: @Yard1

Hi @Jules_Damji,
My sincere apologies for the late reply; I was busy with other high-priority work.

Here is my trainer_init_per_worker code snippet:

import os

import evaluate
import numpy as np
import torch
from transformers import Trainer, TrainingArguments
from transformers import (
    GPTJForCausalLM,
    AutoTokenizer,
    default_data_collator,
    AutoModelForCausalLM
)
from transformers.utils.logging import disable_progress_bar, enable_progress_bar

from ray.air import session


def trainer_init_per_worker(train_dataset, eval_dataset=None, **config):
    # Use the actual number of CPUs assigned by Ray
    os.environ["OMP_NUM_THREADS"] = str(
        session.get_trial_resources().bundles[-1].get("CPU", 1)
    )
    # Enable tf32 for better performance
    torch.backends.cuda.matmul.allow_tf32 = True

    batch_size = config.get("batch_size", 4)
    epochs = config.get("epochs", 1)
    warmup_steps = config.get("warmup_steps", 0)
    learning_rate = config.get("learning_rate", 0.00002)
    weight_decay = config.get("weight_decay", 0.01)

    deepspeed = {
        "fp16": {
            "enabled": "auto",
            "initial_scale_power": 8,
        },
        "bf16": {"enabled": "auto"},
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": "auto",
                "betas": "auto",
                "eps": "auto",
            },
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True,
            },
            "offload_param": {
                "device": "cpu",
                "pin_memory": True,
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "gather_16bit_weights_on_model_save": True,
            "round_robin_gradients": True,
        },
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 10,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False,
    }

    print("Preparing training arguments")
    training_args = TrainingArguments(
        "output",
        per_device_train_batch_size=batch_size,
        logging_steps=1,
        save_strategy="no",
        per_device_eval_batch_size=batch_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        label_names=["input_ids", "attention_mask"],
        num_train_epochs=epochs,
        push_to_hub=False,
        disable_tqdm=True,  # declutter the output a little
        fp16=True,
        gradient_checkpointing=True,
        deepspeed=deepspeed,
    )
    disable_progress_bar()

    # model_name is expected to be defined in the enclosing notebook scope
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    print("Loading model")

    model = GPTJForCausalLM.from_pretrained(model_name, use_cache=False)
    model.resize_token_embeddings(len(tokenizer))

    print("Model loaded")

    enable_progress_bar()

    metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return metric.compute(predictions=predictions, references=labels)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
    )
    return trainer

And here is the trainer construction:

from ray.train.huggingface import HuggingFaceTrainer
from ray.air.config import ScalingConfig
from ray.data.preprocessors import Chain


trainer = HuggingFaceTrainer(
    trainer_init_per_worker=trainer_init_per_worker,
    trainer_init_config={
        "batch_size": 16,  # per device
        "epochs": 1,
    },
    scaling_config=ScalingConfig(
        num_workers=num_workers,
        use_gpu=use_gpu,
        resources_per_worker={"GPU": 1, "CPU": cpus_per_worker},
    ),
    datasets={"train": ray_datasets["train"], "evaluation": ray_datasets["validation"]},
    preprocessor=Chain(splitter, tokenizer),
)

Yes, we are using HF datasets; our notebook has code that converts the HF datasets to Ray Data.
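
Roughly, the conversion looks like this (simplified; "wikitext" is just a stand-in for the actual dataset we fine-tune on):

import ray.data
from datasets import load_dataset

# Load the HF dataset (stand-in name) and convert each split to a Ray Dataset.
hf_datasets = load_dataset("wikitext", "wikitext-2-raw-v1")
ray_datasets = ray.data.from_huggingface(hf_datasets)
# ray_datasets["train"] and ray_datasets["validation"] are what we pass to the trainer.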

The model is created inside the trainer_init_per_worker snippet above.