Scaling Ray Train in PyTorch with multiple GPUs per Worker: AttributeError Issue

Greetings, and thanks for the great library.

I am new to Ray. My goal is to run Ray Tune for hyperparameter optimization (HPO) with PyTorch and multiple GPUs per trial. As a first step, I’m trying to run this tutorial. I made two small changes to the code: I added the missing import ray line, and I increased the resources per worker to two GPUs:

scaling_config = ScalingConfig(
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 2},
)

Here is the full script:

import tempfile
import torch
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import ray
from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig, Checkpoint


def train_func(config):
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
    train_data = FashionMNIST(
        root="./data", train=True, download=True, transform=transform
    )
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        checkpoint_dir = tempfile.gettempdir()
        checkpoint_path = checkpoint_dir + "/model.checkpoint"
        torch.save(model.state_dict(), checkpoint_path)
        # [3] Report metrics and checkpoint.
        ray.train.report(
            {"loss": loss.item()}, checkpoint=Checkpoint.from_directory(checkpoint_dir)
        )


# [4] Configure scaling and resource requirements.
scaling_config = ScalingConfig(
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"CPU": 2, "GPU": 2},
)

# [5] Launch distributed training job.
trainer = TorchTrainer(train_func, scaling_config=scaling_config)
result = trainer.fit()

However, I’m getting the following error:

2023-10-09 12:15:47,848 INFO worker.py:1642 -- Started a local Ray instance.
2023-10-09 12:15:50,003 INFO tune.py:228 -- Initializing Ray automatically. For cluster usage or custom Ray initialization, call `ray.init(...)` before `Trainer(...)`.
2023-10-09 12:15:50,004 INFO tune.py:654 -- [output] This will use the new output engine with verbosity 1. To disable the new output and use the legacy output engine, set the environment variable RAY_AIR_NEW_OUTPUT=0. For more information, please see https://github.com/ray-project/ray/issues/36949

View detailed results here: /gpfs/data/huo-lab/Image/annawoodard/ray_results/TorchTrainer_2023-10-09_12-15-46
To visualize your results with TensorBoard, run: `tensorboard --logdir /gpfs/data/huo-lab/Image/annawoodard/ray_results/TorchTrainer_2023-10-09_12-15-46`

Training started without custom configuration.
(TorchTrainer pid=28494) Starting distributed worker processes: ['28628 (10.50.47.81)', '28629 (10.50.47.81)']
(RayTrainWorker pid=28628) Setting up process group for: env:// [rank=0, world_size=2]
(RayTrainWorker pid=28628) Moving model to device: cuda:0
(RayTrainWorker pid=28628) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=28628) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
(RayTrainWorker pid=28628) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz
  0%|          | 0/26421880 [00:00<?, ?it/s]
  0%|          | 32768/26421880 [00:00<01:23, 316184.94it/s]
  0%|          | 0/26421880 [00:00<?, ?it/s]
 23%|██▎       | 6193152/26421880 [00:05<00:15, 1291076.03it/s] [repeated 94x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/ray-logging.html#log-deduplication for more options.)
 50%|████▉     | 13139968/26421880 [00:07<00:07, 1820380.21it/s]
 68%|██████▊   | 17858560/26421880 [00:10<00:04, 1923702.96it/s] [repeated 88x across cluster]
 91%|█████████▏| 24150016/26421880 [00:13<00:01, 1914159.79it/s]
 92%|█████████▏| 24346624/26421880 [00:13<00:01, 1919435.67it/s]
 93%|█████████▎| 24543232/26421880 [00:13<00:00, 1921696.77it/s]
 94%|█████████▎| 24739840/26421880 [00:13<00:00, 1922382.87it/s]
 94%|█████████▍| 24936448/26421880 [00:14<00:00, 1925794.84it/s]
 95%|█████████▌| 25133056/26421880 [00:14<00:00, 1928056.54it/s]
 96%|█████████▌| 25329664/26421880 [00:14<00:00, 1929893.53it/s]
 97%|█████████▋| 25526272/26421880 [00:14<00:00, 1929571.01it/s]
 97%|█████████▋| 25722880/26421880 [00:14<00:00, 1930254.40it/s]
 98%|█████████▊| 25919488/26421880 [00:14<00:00, 1924792.51it/s]
 99%|█████████▉| 26116096/26421880 [00:14<00:00, 1923559.16it/s]
100%|█████████▉| 26312704/26421880 [00:14<00:00, 1926102.62it/s]
(RayTrainWorker pid=28629) Extracting ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 26421880/26421880 [00:14<00:00, 1780504.50it/s]
(RayTrainWorker pid=28629) 
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
 73%|███████▎  | 19398656/26421880 [00:15<00:05, 1336237.36it/s] [repeated 77x across cluster]
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
  0%|          | 0/29515 [00:00<?, ?it/s]
(RayTrainWorker pid=28629) Extracting ./data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
(RayTrainWorker pid=28629) 
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
100%|██████████| 29515/29515 [00:00<00:00, 225804.18it/s]
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
  0%|          | 0/4422102 [00:00<?, ?it/s]
 94%|█████████▍| 4161536/4422102 [00:02<00:00, 2213862.00it/s]
 99%|█████████▉| 4390912/4422102 [00:02<00:00, 2215275.92it/s]
(RayTrainWorker pid=28629) Extracting ./data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./data/FashionMNIST/raw
100%|██████████| 4422102/4422102 [00:02<00:00, 1847914.53it/s]
 91%|█████████▏| 24150016/26421880 [00:18<00:01, 1425706.76it/s]
(RayTrainWorker pid=28629) 
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
 92%|█████████▏| 24313856/26421880 [00:18<00:01, 1470598.71it/s]
 93%|█████████▎| 24477696/26421880 [00:18<00:01, 1504863.37it/s]
(RayTrainWorker pid=28629) Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
 93%|█████████▎| 24641536/26421880 [00:18<00:01, 1535790.03it/s]
 94%|█████████▍| 24805376/26421880 [00:19<00:01, 1559641.53it/s]
  0%|          | 0/5148 [00:00<?, ?it/s]
(RayTrainWorker pid=28629) Extracting ./data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/FashionMNIST/raw
(RayTrainWorker pid=28629) 
2023-10-09 12:16:17,232 ERROR tune_controller.py:1502 -- Trial task failed for trial TorchTrainer_7e641_00000
Traceback (most recent call last):
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 24, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/_private/worker.py", line 2547, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::_Inner.train() (pid=28494, ip=10.50.47.81, actor_id=78d70ddaa4899021bb074cd201000000, repr=TorchTrainer)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 400, in train
    raise skipped from exception_cause(skipped)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 54, in check_for_failure
    ray.get(object_ref)
ray.exceptions.RayTaskError(AttributeError): ray::_RayTrainWorker__execute.get_next() (pid=28629, ip=10.50.47.81, actor_id=11c8474abebbccccd941581101000000, repr=<ray.train._internal.worker_group.RayTrainWorker object at 0x7fb33070d4e0>)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/train/_internal/worker_group.py", line 33, in __execute
    raise skipped from exception_cause(skipped)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 129, in discard_return_wrapper
    train_func(*args, **kwargs)
  File "/gpfs/data/huo-lab/Image/annawoodard/medtectron/train_foo.py", line 32, in train_func
    train_loader = ray.train.torch.prepare_data_loader(train_loader)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/train/torch/train_loop_utils.py", line 142, in prepare_data_loader
    return get_accelerator(_TorchAccelerator).prepare_data_loader(
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/train/torch/train_loop_utils.py", line 461, in prepare_data_loader
    data_loader = _WrappedDataLoader(data_loader, device, auto_transfer)
  File "/gpfs/data/huo-lab/Image/annawoodard/miniconda3/envs/medtectron/lib/python3.10/site-packages/ray/train/torch/train_loop_utils.py", line 514, in __init__
    self._auto_transfer = auto_transfer if device.type == "cuda" else False
AttributeError: 'list' object has no attribute 'type'

Is this the proper way to scale up to two GPUs per worker? If not, can you point me to an example of how to do so? Note that for various reasons, I prefer to use vanilla PyTorch rather than PyTorch Lightning. I am using Ray 2.7.0. Many thanks!
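
For readers following the traceback: the last frame evaluates device.type, and the error message says that device was a list. A plausible reading, stated here as an assumption rather than something the log confirms, is that requesting two GPUs per worker makes Ray Train hand the data loader wrapper a list of devices instead of a single one. The sketch below reproduces only that failure mode in plain Python. FakeDevice and auto_transfer_enabled are hypothetical stand-ins, not Ray or PyTorch APIs.

```python
from dataclasses import dataclass


@dataclass
class FakeDevice:
    """Hypothetical stand-in for torch.device; only the .type field matters here."""

    type: str


def auto_transfer_enabled(device, auto_transfer=True):
    # Same expression as the last frame of the traceback.
    return auto_transfer if device.type == "cuda" else False


# One device per worker: the attribute lookup succeeds.
single = FakeDevice("cuda")
single_result = auto_transfer_enabled(single)
print(single_result)

# Two devices per worker (assumed to arrive as a list): the lookup fails.
multi = [FakeDevice("cuda"), FakeDevice("cuda")]
try:
    auto_transfer_enabled(multi)
    message = ""
except AttributeError as err:
    message = str(err)
print(message)
```

The second call raises the same AttributeError as in the log, because a Python list has no type attribute.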

Hi @annawoodard, with Ray TorchTrainer, if you want to use two GPUs for DDP training, you should specify ScalingConfig as below.

ScalingConfig(num_workers=2, use_gpu=True, resources_per_worker={"GPU":1})

(By default, use_gpu=True already allocates one GPU per worker, so there is no need to specify resources_per_worker here.)
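
The original script asked for 2 workers with 2 GPUs each, so 4 GPUs in total. Following the advice above, one way to keep the same total is to raise num_workers and leave each worker with a single GPU. Here is a minimal sketch of that mapping. The helper name one_gpu_per_worker_kwargs is hypothetical and not part of Ray, and the CPU count is simply carried over from the question.

```python
def one_gpu_per_worker_kwargs(total_gpus, cpus_per_worker=2):
    """Build ScalingConfig keyword arguments for DDP with one GPU per worker.

    Hypothetical helper for illustration; it is not part of Ray.
    """
    if total_gpus < 1:
        raise ValueError("total_gpus must be at least 1")
    return {
        "num_workers": total_gpus,  # one DDP worker process per GPU
        "use_gpu": True,
        "resources_per_worker": {"CPU": cpus_per_worker, "GPU": 1},
    }


kwargs = one_gpu_per_worker_kwargs(total_gpus=4)
print(kwargs)

# With Ray installed, the dictionary can be passed straight through:
#   from ray.train import ScalingConfig
#   scaling_config = ScalingConfig(**kwargs)
```

The rest of the training script stays unchanged; only the scaling configuration differs.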