Ray cluster: why does the program freeze and stop executing when the GPUs it requires span two machines?

Thank you very much to the Ray team for building great software. I have a question: my Ray cluster has two machines, one with 2 GPUs and one with 4 GPUs. When I run the program and request 6 GPUs (6 workers, one GPU each), it sometimes freezes and never starts training. What could be the reason?
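
Before submitting the job, a minimal sanity check of the cluster resources (run on the head node) would look like the snippet below; with both machines connected, ray.cluster_resources() should report 6 GPUs. This is just a check I am assuming is useful here, not part of the training script:

import ray

# Connect to the already-running cluster.
ray.init(address="auto")

# With both machines joined, these should report 6.0 GPUs in total.
print(ray.cluster_resources())
print(ray.available_resources())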

My program is as follows:

import argparse
from typing import Dict

import ray
from ray.air import session

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

import ray.train as train
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

# Download training data from open datasets.
training_data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.MNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)


# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) // session.get_world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            # print(f"batch_size: {len(X)}")
            # print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
    print(f"Train Rank[{session.get_world_rank()}/{session.get_world_size()}]  loss {loss}")


def validate_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset) // session.get_world_size()
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            # print(f"Valid batch_size:{len(X)}")
    test_loss /= num_batches
    correct /= size
    # print(
    #     f"Test Error:  "
    #     f"Accuracy: {(100 * correct):>0.1f}%, "
    #     f"Avg loss: {test_loss:>8f} \n"
    # )
    print(
        f"Vaild Rank[{session.get_world_rank()}/{session.get_world_size()}]  loss:{test_loss:.4f} correct:{correct:.4f}")
    return test_loss


def train_func(config: Dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]

    worker_batch_size = batch_size // session.get_world_size()

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=worker_batch_size)
    test_dataloader = DataLoader(test_data, batch_size=worker_batch_size)

    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    model = NeuralNetwork()
    print("before wrap model......")
    model = train.torch.prepare_model(model)
    print("after wrap model......")

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    loss_results = []

    for t in range(epochs):
        print("*" * 20, t, "*" * 20)
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        loss_results.append(loss)
        session.report(dict(loss=loss))
        print(f"Outer Valid Rank[{session.get_world_rank()}/{session.get_world_size()}] epoch[{t}]", loss)

    # return required for backwards compatibility with the old API
    # TODO(team-ml) clean up and remove return
    return loss_results


def train_fashion_mnist(num_workers=2, use_gpu=False):
    trainer = TorchTrainer(
        train_func,
        train_loop_config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    result = trainer.fit()
    print(f"Results: {result.metrics}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num_workers",
        "-n",
        type=int,
        default=3,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use_gpu", action="store_false", default=True, help="Enables GPU training"
    )

    args, _ = parser.parse_known_args()

    ray.init(address=args.address or "auto")
    train_fashion_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu)
    print(vars(args))
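
For what it's worth, my understanding of the resource request is that each of the 6 workers reserves 1 CPU and 1 GPU, so the placement group has to span both machines. Below is a sketch of the same request with the per-worker bundle spelled out explicitly (this assumes the default of 1 CPU and 1 GPU per worker; I have not verified that writing it this way changes the behavior):

from ray.air.config import ScalingConfig

# Equivalent request with the per-worker resources made explicit:
# 6 workers, each reserving 1 CPU and 1 GPU.
scaling_config = ScalingConfig(
    num_workers=6,
    use_gpu=True,
    resources_per_worker={"CPU": 1, "GPU": 1},
)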

Start command:

python demo_dis_1.py --num_workers 6

It gets stuck at this point; the status shows only 5.0/6 GPUs requested and the trial has been RUNNING for over 25 minutes without training:

+--------------------------+----------+-----------------+
| Trial name               | status   | loc             |
|--------------------------+----------+-----------------|
| TorchTrainer_1250a_00000 | RUNNING  | 10.51.6.11:8588 |
+--------------------------+----------+-----------------+


== Status ==
Current time: 2023-01-14 09:27:20 (running for 00:25:35.10)
Memory usage on this node: 235.6/503.6 GiB 
Using FIFO scheduling algorithm.
Resources requested: 1.0/160 CPUs, 5.0/6 GPUs, 0.0/335.74 GiB heap, 0.0/147.88 GiB objects (0.0/1.0 accelerator_type:G)
Result logdir: /home/jeff/ray_results/TorchTrainer_2023-01-14_09-01-45
Number of trials: 1/1 (1 RUNNING)
+--------------------------+----------+-----------------+
| Trial name               | status   | loc             |
|--------------------------+----------+-----------------|
| TorchTrainer_1250a_00000 | RUNNING  | 10.51.6.11:8588 |
+--------------------------+----------+-----------------+