Ray Train does not work with a pretrained model

How severely does this issue affect your experience of using Ray?

  • None: Just asking a question out of curiosity
  • Low: It annoys or frustrates me for a moment.
  • Medium: It makes it significantly harder to complete my task, but I can work around it.
  • High: It blocks me from completing my task.

When I use a pretrained Hugging Face Transformers model, training does not work:

import itertools
import os

import ray
import ray.train as train
import torch
from ray.train.torch import TorchConfig
from ray.train.trainer import Trainer
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertTokenizer, BertForMaskedLM


model_name = "cahya/bert-base-indonesian-522M"
# Last path component of the hub id (e.g. "bert-base-indonesian-522M").
# [-1] also works for ids without an organisation prefix, where [1] would
# raise IndexError.
model_dir_name = model_name.split("/")[-1]
# Expand "~" explicitly: the Hugging Face cache does not appear to expand it
# itself, which would create a literal "~" directory under the current
# working directory instead of using the home directory.
tokenizer = BertTokenizer.from_pretrained(
    model_name, cache_dir=os.path.expanduser(f"~/models/{model_dir_name}")
)
from transformers import DataCollatorForLanguageModeling

# Masked-language-model collator shared by the dataset: given token ids it
# returns masked input_ids plus matching labels, with masking probability 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer, mlm=True, mlm_probability=0.15
)

def my_collate(batch):
    """Attach masked-LM inputs and labels to one tokenized example.

    ``batch`` is the mapping returned by ``tokenizer(...)`` for a single text
    (presumably plain lists of ints, since the caller passes
    ``return_tensors=None``). It is updated in place and returned with every
    field converted to a ``torch.long`` tensor. ``input_ids`` is replaced by
    the masked ids produced by ``data_collator``, and ``labels`` holds the
    matching prediction targets.
    """
    # data_collator operates on a batch, so wrap the single example in a list.
    masked = data_collator([batch["input_ids"]])
    # Use the collator's *masked* input_ids as well as its labels. Keeping the
    # original ids would leave the true token visible at every position the
    # model is asked to predict. [0] drops the batch-of-one dimension so these
    # fields have the same per-example shape as attention_mask/token_type_ids.
    batch["input_ids"] = masked["input_ids"][0]
    batch["labels"] = masked["labels"][0]
    for key, value in batch.items():
        if not isinstance(value, torch.Tensor):
            batch[key] = torch.tensor(value, dtype=torch.long)
    return batch

class MyDataSet(Dataset):
    """Masked-LM dataset over the first ``max_lines`` lines of a text file.

    Each item is tokenized with the module-level ``tokenizer`` (padded and
    truncated to 256 tokens) and passed through ``my_collate`` to attach
    masked inputs and labels.
    """

    def __init__(self, txt_path="/data/all_corpus.txt", max_lines=5000):
        """Load up to ``max_lines`` lines from ``txt_path``.

        Args:
            txt_path: Path of the text corpus, one example per line
                (assumed UTF-8).
            max_lines: Maximum number of lines to load.
        """
        with open(txt_path, encoding="utf-8") as f:
            # islice stops at EOF, so a file shorter than max_lines yields
            # fewer items. Repeated readline() past EOF would instead pad
            # the dataset with empty strings.
            self.lines = list(itertools.islice(f, max_lines))

    def __getitem__(self, item):
        """Return ``(input_ids, attention_mask, token_type_ids, labels)`` for line ``item``."""
        text = self.lines[item].strip()
        encoded = tokenizer(
            text,
            max_length=256,
            truncation=True,
            return_tensors=None,
            padding="max_length",
        )
        input_dict = my_collate(encoded)
        return (
            input_dict["input_ids"],
            input_dict["attention_mask"],
            input_dict["token_type_ids"],
            input_dict["labels"],
        )

    def __len__(self):
        """Return the number of loaded lines."""
        return len(self.lines)

def train_func(config: dict):
    """Per-worker training function run by the Ray Train ``Trainer``.

    Args:
        config: Must provide ``"batch_size"`` (the global batch size, divided
            across all workers) and ``"epochs"`` (number of passes over the
            data).
    """
    batch_size = config["batch_size"]
    epochs = config["epochs"]

    # Split the global batch across workers, but never down to 0. DataLoader
    # rejects batch_size=0, which floor division yields when
    # batch_size < world_size.
    worker_batch_size = max(1, batch_size // train.world_size())
    print("worker_batch_size: ", worker_batch_size)
    # Create data loaders. Per the Ray Train docs, prepare_data_loader adds a
    # distributed sampler (one shard per worker) and moves batches to the
    # worker's device.
    train_dataloader = DataLoader(MyDataSet(), batch_size=worker_batch_size, shuffle=True)
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)

    # Create model. Expand "~" explicitly: the Hugging Face cache does not
    # appear to expand it, which would create a literal "~" directory under
    # the worker's current working directory.
    model = BertForMaskedLM.from_pretrained(
        model_name, cache_dir=os.path.expanduser(f"~/models/{model_dir_name}")
    )
    model = train.torch.prepare_model(model)

    sampler = getattr(train_dataloader, "sampler", None)
    for epoch in range(epochs):
        # A DistributedSampler only changes its shuffle order when told the
        # current epoch. Without this call every epoch visits the data in the
        # same order.
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        train_epoch(train_dataloader, model)

def train_epoch(dataloader, model, optimizer=None, *, train_step=False):
    """Run one pass over ``dataloader`` with ``model``.

    By default this only runs forward passes and prints each batch and model
    output, which is useful for locating where distributed training stalls.
    Pass ``train_step=True`` to also backpropagate ``output.loss`` and update
    the weights.

    Args:
        dataloader: Yields 4-element batches
            ``(input_ids, attention_mask, token_type_ids, labels)``.
        model: The (possibly Ray/DDP-wrapped) masked-LM model.
        optimizer: Optimizer to step. If ``None``, a fresh Adam (lr=2e-5) is
            created, which resets optimizer state on every call. Pass one in
            to keep that state across epochs.
        train_step: When true, run zero_grad/backward/step for each batch.
    """
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    num_batches = len(dataloader)
    for index, batch in tqdm(enumerate(dataloader), desc="train desc: ", total=num_batches):
        input_ids, attention_mask, token_type_ids, labels = batch
        print("batch: ", batch)
        output = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
        print("output: ", output)
        if not train_step:
            continue
        loss = output.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Progress is reported against this worker's batch count; the world
        # size is not a meaningful denominator here.
        print(f"loss: {loss.item():>7f}  [{index + 1:>5d}/{num_batches:>5d}]")

if __name__ == '__main__':
    # NOTE(review): each node has 4 GPUs, so num_workers=4 may be placed on a
    # single node, while num_workers=5 needs at least two. Only the 5-worker
    # run therefore necessarily exercises cross-node collective communication.
    # If it hangs, check that the backend can reach the other nodes (e.g. set
    # GLOO_SOCKET_IFNAME for gloo, or try backend="nccl" with NCCL_DEBUG=INFO
    # since use_gpu=True). TODO: confirm on the cluster.
    trainer = Trainer(backend=TorchConfig(backend="gloo"), num_workers=5, use_gpu=True)
    # The legacy Trainer must be started before run() and shut down afterward
    # so its worker actors are released even if training fails.
    trainer.start()
    try:
        result = trainer.run(train_func, config={"batch_size": 60, "epochs": 5})
    finally:
        trainer.shutdown()

My Ray cluster: 8 nodes, each with 4 GPUs.
When I set num_workers=5, the printed "batch:" and "output:" messages show that the model gets stuck after four forward passes.
When I set num_workers=4, it works fine.
I suspect a communication issue between nodes, but I don't know how to troubleshoot it because no log output is produced.