[SGD] [Tune] How does the performance of RaySGD compare with PyTorch DDP?

Hi there,
I noticed that RaySGD provides benchmarking results comparing it with some existing solutions for parallel or distributed training. It compares RaySGD with PyTorch DataParallel (DP) but not with DistributedDataParallel (DDP). Since DDP is superior to DP and RaySGD wraps DDP, why don’t you compare RaySGD with DDP?


cc @rliaw who worked on the SGD benchmarks

RaySGD wraps DDP, so comparing the two should give exactly the same result.

I thought so, but I did a simple experiment that shows RaySGD is slower than the original DDP.
In detail, I compared 4 Ray workers (on one node) with 4 DDP processes (on one node) on CIFAR-100 with the same settings. The results show that they reach almost the same accuracy and GPU memory usage, but RaySGD’s running time is 1602.94s while DDP’s is 1380.18s. Is there anything wrong?

OK that’s weird!

Could you help provide a benchmark script for that? It’d be great to understand why this is happening.

@rliaw Thanks for your attention! I forgot to mention that the Ray version is 1.1.0. Here are my scripts.

Ray script:

from torch.utils.data import Dataset
import torch
import torchvision
from torchvision import transforms
import numpy as np
import os
from PIL import Image
import ray
from ray.util.sgd.torch import TrainingOperator
from ray.util.sgd import TorchTrainer
from torch.utils.data import DataLoader


class Cifar100TrainingOperator(TrainingOperator):
    def setup(self, config):
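        # Note: batch_size, model, optimizer, criterion, and scheduler used below are
        # the module-level globals defined under __main__ at the bottom of this script.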

        trans = config["trans"]

        train_dataset = torchvision.datasets.CIFAR100(root="./cifar100_data", train=True, download=True, transform=trans)
        val_dataset = torchvision.datasets.CIFAR100(root="./cifar100_data", train=False, download=True, transform=trans)

        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size)

        if criterion is not None and scheduler is not None:
            self.model, self.optimizer, self.criterion, self.scheduler = self.register(
                models=model,
                optimizers=optimizer,
                criterion=criterion,
                schedulers=scheduler)
        elif criterion is not None:
            self.model, self.optimizer, self.criterion = self.register(
                models=model,
                optimizers=optimizer,
                criterion=criterion)
        elif scheduler is not None:
            self.model, self.optimizer, self.scheduler = self.register(
                models=model,
                optimizers=optimizer,
                schedulers=scheduler)
        else:
            self.model, self.optimizer = self.register(
                models=model,
                optimizers=optimizer)

        self.register_data(
            train_loader=train_loader,
            validation_loader=val_loader)


if __name__ == '__main__':
    # for local
    ray.init()
    # for cluster
    # ray.init(address='auto')

    batch_size = 512
    epochs = 200
    scheduler = None
    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])
    import timm
    model = timm.create_model('resnet50', pretrained=False, num_classes=100)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    trainer = TorchTrainer(
        training_operator_cls=Cifar100TrainingOperator,
        num_workers=4,
        config={"trans": trans},
        use_gpu=True,
    )

    train_loss = []
    val_loss = []
    val_acc = []

    import time
    start_time = time.time()

    for i in range(epochs):
        train_info = trainer.train()
        print("train:", train_info)
        train_loss.append(train_info['train_loss'])

        eval_info = trainer.validate()
        print("eval:", eval_info, '\n')
        val_loss.append(eval_info['val_loss'])
        val_acc.append(eval_info['val_accuracy'])

    end_time = time.time()
    used_time = str(round(end_time - start_time, 2))

    import matplotlib.pyplot as plt

    plt.figure(dpi=300, figsize=(30, 8))
    plt.subplot(1, 3, 1)
    plt.plot(train_loss)
    plt.xlabel("epoch")
    plt.ylabel("train_loss")
    plt.subplot(1, 3, 2)
    plt.plot(val_loss)
    plt.xlabel("epoch")
    plt.ylabel("val_loss")
    plt.subplot(1, 3, 3)
    plt.plot(val_acc)
    plt.xlabel("epoch")
    plt.ylabel("val_acc")
    plt.savefig("fig_ray_cifar100_resnet50-" + str(epochs) + "-" + str(batch_size) + "-" +
                used_time + "s.png")

    print("============= Training finished ==============")
    print("total time:", used_time)

DDP script:

from torch.utils.data import Dataset
import torch
from torchvision import transforms
import torchvision
import numpy as np
import os
from PIL import Image
from torch.utils.data import DataLoader
import torch
import argparse

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


if __name__ == '__main__':
    gpu_num = torch.cuda.device_count()
    dist.init_process_group(backend='nccl')
    local_rank = torch.distributed.get_rank() % gpu_num
    torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}")
    assert torch.distributed.is_nccl_available() is True
    print("GPU num:", gpu_num, "device:", device)

    batch_size = 512
    epochs = 200

    trans = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
    ])

    train_dataset = torchvision.datasets.CIFAR100(root="./cifar100_data", train=True, download=True, transform=trans)
    val_dataset = torchvision.datasets.CIFAR100(root="./cifar100_data", train=False, download=True, transform=trans)

    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, sampler=val_sampler)

    import timm
    model = timm.create_model('resnet50', pretrained=False, num_classes=100).to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)#, find_unused_parameters=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    train_loss = []
    val_loss = []
    val_acc = []

    import time

    start_time = time.time()

    for i in range(epochs):
        model.train()
        ######### train #########
        train_loader.sampler.set_epoch(i)

        epoch_train_loss = 0
        for input, target in train_loader:
            input, target = input.to(device), target.to(device)
            # compute output
            output = model(input)

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            batch_train_loss = loss.item()
            epoch_train_loss += batch_train_loss
        train_loss.append(epoch_train_loss / len(train_loader))
        print("avg_batch_train_loss:", epoch_train_loss / len(train_loader))

        model.eval()
        ######### val #########
        correct = 0
        total = 0
        with torch.no_grad():
            epoch_val_loss = 0
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                # forward
                out = model(images)
                loss = criterion(out, labels)
                batch_val_loss = loss.item()
                epoch_val_loss += batch_val_loss
                _, pred = torch.max(out, 1)
                correct += (pred == labels).sum().item()
                total += labels.size(0)
            val_loss.append(epoch_val_loss / len(val_loader))
            accuracy = float(correct) / total
            val_acc.append(accuracy)
            print("avg_batch_val_loss:", epoch_val_loss / len(val_loader))
            print("accuracy:", accuracy)

    end_time = time.time()
    used_time = str(round(end_time - start_time, 2))

    if dist.get_rank() == 0:
        import matplotlib.pyplot as plt
        plt.figure(dpi=300, figsize=(30, 8))
        plt.subplot(1, 3, 1)
        plt.plot(train_loss)
        plt.xlabel("epoch")
        plt.ylabel("train_loss")
        plt.subplot(1, 3, 2)
        plt.plot(val_loss)
        plt.xlabel("epoch")
        plt.ylabel("val_loss")
        plt.subplot(1, 3, 3)
        plt.plot(val_acc)
        plt.xlabel("epoch")
        plt.ylabel("val_acc")
        plt.savefig("fig_ddp_cifar100_resnet50-" + str(epochs) + "-" + str(batch_size) + "-" +
                    used_time + "s.png")

        print("============= Training finished ==============")
        print("total time:", used_time)

command line:

python -m torch.distributed.launch --nproc_per_node 4 experiments/ddp_cifar100_resnet50.py

OK thanks! How many CPUs are on your machine? As a first step, could you try setting this in the initialization_hook of RaySGD: OMP_NUM_THREADS=[total CPUs on machine / 4]?
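For example, something like this (just a sketch; adjust the divisor to your number of workers):

import os

# Sketch: RaySGD runs this hook on every worker process before training starts,
# so the environment variable applies to each worker.
def initialization_hook():
    threads_per_worker = max(1, (os.cpu_count() or 1) // 4)  # total CPUs / 4 workers
    os.environ["OMP_NUM_THREADS"] = str(threads_per_worker)

# trainer = TorchTrainer(..., initialization_hook=initialization_hook, ...)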

I tried this, but it didn’t make a difference. My machine has 4 GPUs and 40 CPUs, and I followed the single-GPU-per-process pattern. Further experimental results are as follows:

Ray:
OMP_NUM_THREADS = 1 (default) total time: 1225.53
OMP_NUM_THREADS = 10 total time: 1230.6

DDP:
OMP_NUM_THREADS = 1 (default) total time: 984.03
OMP_NUM_THREADS = 10 total time: 982.89

Ah, can you also try setting the backend of RaySGD to be NCCL?

I think the default backend is NCCL because the arg use_gpu is True.

backend (string) – backend used by distributed PyTorch. Currently supports “nccl”, “gloo”, and “auto”. If “auto”, RaySGD will automatically use “nccl” if use_gpu is True, and “gloo” otherwise.

BTW, I explicitly specified the backend as NCCL:

trainer = TorchTrainer(
    training_operator_cls=Cifar100TrainingOperator,
    initialization_hook=initialization_hook,
    num_workers=4,
    config={"trans": trans},
    use_gpu=True,
    backend='nccl',
)

The total time is 1227.86s.

Thanks!

Could you also use trainer.train(profile=True) and post the output of that? This will give some lower-level profiling numbers (e.g., per-batch time, data loading, etc.). It’d be great to also get the equivalent numbers on the DDP side.
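For example, roughly (a sketch; the profile dict holds the per-phase timers):

# Collect RaySGD's built-in timers each epoch and print them.
train_stats = trainer.train(profile=True)
print(train_stats["profile"])   # mean_train_epoch_s, mean_fwd_s, mean_grad_s, mean_apply_s

val_stats = trainer.validate(profile=True)
print(val_stats["profile"])     # mean_validation_s, mean_eval_fwd_s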

Finally, can you check nvidia-smi while the workload is running to confirm that both runs use the same GPUs, the same amount of memory, etc.?

OK, thanks. I tried printing detailed training and validation info. Part of the results is as follows:

RaySGD:

epoch: 147
current epoch training time: 5.65
train: {'num_samples': 50000, 'epoch': 148.0, 'batch_count': 25.0, 'train_loss': 0.0005675839141034521, 'last_train_loss': 0.00011753707804018632, 'profile': {'mean_train_epoch_s': 5.641417980194092, 'mean_fwd_s': 0.011049604415893555, 'mean_grad_s': 0.13608250617980958, 'mean_apply_s': 0.02022266387939453}}
current epoch val time: 0.49
eval: {'num_samples': 10000, 'batch_count': 5.0, 'val_loss': 4.831029027557373, 'last_val_loss': 4.9381022453308105, 'val_accuracy': 0.3411, 'last_val_accuracy': 0.3346238938053097, 'profile': {'mean_validation_s': 0.47080397605895996, 'mean_eval_fwd_s': 0.006463813781738281}}

epoch: 148
current epoch training time: 5.66
train: {'num_samples': 50000, 'epoch': 149.0, 'batch_count': 25.0, 'train_loss': 0.0005811023716226918, 'last_train_loss': 0.0001186984281957848, 'profile': {'mean_train_epoch_s': 5.6611151695251465, 'mean_fwd_s': 0.011197566986083984, 'mean_grad_s': 0.13827722072601317, 'mean_apply_s': 0.019513559341430665}}
current epoch val time: 0.48
eval: {'num_samples': 10000, 'batch_count': 5.0, 'val_loss': 4.837075037384033, 'last_val_loss': 4.945190906524658, 'val_accuracy': 0.3408, 'last_val_accuracy': 0.3274336283185841, 'profile': {'mean_validation_s': 0.47136616706848145, 'mean_eval_fwd_s': 0.006566476821899414}}

epoch: 149
current epoch training time: 5.62
train: {'num_samples': 50000, 'epoch': 150.0, 'batch_count': 25.0, 'train_loss': 0.000528241263361997, 'last_train_loss': 0.002027789820203907, 'profile': {'mean_train_epoch_s': 5.612389087677002, 'mean_fwd_s': 0.011083316802978516, 'mean_grad_s': 0.13872356414794923, 'mean_apply_s': 0.01943700313568115}}
current epoch val time: 0.51
eval: {'num_samples': 10000, 'batch_count': 5.0, 'val_loss': 4.838133937644958, 'last_val_loss': 4.939170241355896, 'val_accuracy': 0.3403, 'last_val_accuracy': 0.3313053097345133, 'profile': {'mean_validation_s': 0.46913623809814453, 'mean_eval_fwd_s': 0.006501960754394531}}

epoch: 150
current epoch training time: 5.66
train: {'num_samples': 50000, 'epoch': 151.0, 'batch_count': 25.0, 'train_loss': 0.0006256205962673994, 'last_train_loss': 0.00013180099085730035, 'profile': {'mean_train_epoch_s': 5.652625560760498, 'mean_fwd_s': 0.01103370189666748, 'mean_grad_s': 0.13629908561706544, 'mean_apply_s': 0.019884705543518066}}
current epoch val time: 0.49
eval: {'num_samples': 10000, 'batch_count': 5.0, 'val_loss': 4.8417091506958005, 'last_val_loss': 4.940457344055176, 'val_accuracy': 0.3425, 'last_val_accuracy': 0.33407079646017696, 'profile': {'mean_validation_s': 0.4683079719543457, 'mean_eval_fwd_s': 0.006442117691040039}}

DDP (the printing format is a bit messy because all 4 processes print their results, but I don’t think it matters in terms of running time):

batch time: 0.07274174690246582
current epoch training time: 4.46
avg_batch_train_loss: 0.014519559629261493
current epoch val time: 0.45
avg_batch_val_loss: 4.392485427856445
accuracy: 0.3104
epoch: 147
current epoch val time: 0.46
avg_batch_val_loss: 4.350419330596924
accuracy: 0.3112
epoch: 147
current epoch val time: 0.46
avg_batch_val_loss: 4.28261432647705
accuracy: 0.3252
epoch: 147
current epoch val time: 0.46
avg_batch_val_loss: 4.28979377746582
accuracy: 0.3308
epoch: 147
batch time: 0.07245039939880371
current epoch training time: 4.45
batch time: 0.07256603240966797
avg_batch_train_loss: 0.006455761110410094
current epoch training time: 4.46
batch time: 0.07230877876281738
avg_batch_train_loss: 0.005492587890475988
current epoch training time: 4.46
avg_batch_train_loss: 0.006644705887883902
batch time: 0.07229733467102051
current epoch training time: 4.45
avg_batch_train_loss: 0.005790041536092758
current epoch val time: 0.46
avg_batch_val_loss: 4.38764705657959
accuracy: 0.3264
epoch: 148
current epoch val time: 0.46
avg_batch_val_loss: 4.289935874938965
accuracy: 0.3268
epoch: 148
current epoch val time: 0.46
avg_batch_val_loss: 4.300405406951905
accuracy: 0.3248
epoch: 148
current epoch val time: 0.46
avg_batch_val_loss: 4.231341171264648
accuracy: 0.336
epoch: 148
batch time: 0.07226800918579102
batch time: 0.07186603546142578
current epoch training time: 4.44
avg_batch_train_loss: 0.002534276871010661
batch time: 0.0719599723815918
current epoch training time: 4.44
avg_batch_train_loss: 0.0030480825062841176
batch time: 0.072052001953125
current epoch training time: 4.44
avg_batch_train_loss: 0.0023561480874195696
current epoch training time: 4.43
avg_batch_train_loss: 0.003187773786485195
current epoch val time: 0.46
avg_batch_val_loss: 4.287508010864258
accuracy: 0.3312
current epoch val time: 0.46
epoch: 149
avg_batch_val_loss: 4.385404491424561
accuracy: 0.3292
epoch: 149
current epoch val time: 0.46
avg_batch_val_loss: 4.317112636566162
accuracy: 0.33
epoch: 149
current epoch val time: 0.46
avg_batch_val_loss: 4.2344895362854
accuracy: 0.34
epoch: 149
batch time: 0.07316398620605469
current epoch training time: 4.47
avg_batch_train_loss: 0.0018080707266926765
batch time: 0.07324981689453125
batch time: 0.07298779487609863
current epoch training time: 4.46
avg_batch_train_loss: 0.002247629172634333
current epoch training time: 4.46
avg_batch_train_loss: 0.001824645777232945
batch time: 0.07291889190673828
current epoch training time: 4.47
avg_batch_train_loss: 0.002398280827328563
current epoch val time: 0.45
avg_batch_val_loss: 4.373600196838379
accuracy: 0.3328

To sum up, I trained both RaySGD and DDP for 200 epochs with the same settings.
RaySGD’s training epoch time is ~5.64s, val time is ~0.49s.
total time = (5.64 + 0.49) * 200 = 1226s


DDP’s training epoch time is ~4.45s, val time is ~0.46s.
total time = (4.45 + 0.46) * 200 = 982s
The nvidia-smi screenshot for DDP will be put below because the forum reminds me that “new users can only put one embedded media item in a post.”

The running time matches the above experimental results, and both have roughly the same GPU memory usage.

DDP:

Hi Richard, do you know what’s wrong? Please let me know if you have solutions.

Thanks for doing this! This is really helpful. Sorry for the slow reply!

What is batch time: 0.07316398620605469 measuring for DDP?

In RaySGD we have the following (from epoch 147):

  • mean_fwd_s: 0.011049604415893555
  • mean_grad_s: 0.13608250617980958
  • mean_apply_s: 0.02022266387939453
  • num_batches = 25

Thus, the time for training (excluding data loading) is 25 * 0.136 = 4.16. So something is taking 1.5 extra seconds (perhaps it is a data loader related latency bump). Is it possible for you to help investigate this?

Reaaaaally appreciate your help!

That is the time of the last batch of each epoch. It doesn’t seem useful, so I removed it.

Then I mimicked Ray’s code and separately measured DDP’s average forward, backward, and apply (optimizer step) times.

Here’s the Ray code:

        # Compute output.
        with self.timers.record("fwd"):
            output = model(*features)
            loss = criterion(output, target)

        # Compute gradients in a backward pass.
        with self.timers.record("grad"):
            optimizer.zero_grad()
            if self.use_fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

        # Call step of optimizer to update model params.
        with self.timers.record("apply"):
            optimizer.step()

My DDP code:

        epoch_train_loss = 0
        fwd_time = bwd_time = apply_time = 0
        for x, y in train_loader:
            batch_t1 = time.time()
            x, y = x.to(device), y.to(device)
            # forward
            fwd_t1 = time.time()
            pred = model(x)
            loss = criterion(pred, y)
            fwd_t2 = time.time()
            fwd_time += fwd_t2 - fwd_t1
            # backward
            bwd_t1 = time.time()
            optimizer.zero_grad()
            loss.backward()
            bwd_t2 = time.time()
            bwd_time += bwd_t2 - bwd_t1
            # apply
            apply_t1 = time.time()
            optimizer.step()
            apply_t2 = time.time()
            apply_time += apply_t2 - apply_t1

            batch_train_loss = loss.cpu().item() * len(x)
            epoch_train_loss += batch_train_loss
            batch_t2 = time.time()
            #print("batch_loss:", batch_train_loss)
            #print("batch_time:", batch_t2 - batch_t1)
        
        if dist.get_rank() == 0:
            print("fwd_time:", fwd_time / len(train_loader))
            print("bwd_time:", bwd_time / len(train_loader))
            print("apply_time:", apply_time / len(train_loader))

Part of the results of my DDP is as follows:

fwd_time: 0.01103464126586914
bwd_time: 0.02290154457092285
apply_time: 0.09678736686706543

fwd_time: 0.00990945816040039
bwd_time: 0.01781126022338867
apply_time: 0.10193952560424804

fwd_time: 0.009940462112426758
bwd_time: 0.017059249877929686
apply_time: 0.10256013870239258

fwd_time: 0.009918899536132812
bwd_time: 0.01674846649169922
apply_time: 0.1029355525970459

fwd_time: 0.00994701385498047
bwd_time: 0.016628246307373046
apply_time: 0.10304969787597656

fwd_time: 0.009907283782958985
bwd_time: 0.016851835250854492
apply_time: 0.10303903579711914

fwd_time: 0.009973907470703125
bwd_time: 0.016728286743164063
apply_time: 0.10317913055419922

fwd_time: 0.00990736961364746
bwd_time: 0.016668014526367188
apply_time: 0.10340714454650879

I noticed that most of the time is spent in backward in Ray but in apply in DDP. Is there anything wrong?

Besides, maybe there’s a miscalculation in “Thus, the time for training (excluding data loading) is 25 * 0.136 = 4.16.” Do you mean 25 * (0.011 + 0.136 + 0.02) = 4.175 ≈ 4.16?
If so, DDP’s training time is 25 * (0.01 + 0.017 + 0.103) = 3.25, which is still faster than Ray.

Even if the extra time goes to data loading or anything else, DDP should spend the same amount of extra time. After doing the calculation, both indeed spend roughly the same extra time.

I found my previous claim was wrong because I wasn’t measuring the time correctly: CUDA kernels run asynchronously, so CPU-side timestamps don’t reflect when the GPU work actually finishes. I now call torch.cuda.synchronize() before recording each timestamp. The updated loop is sketched below, and the results follow:
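Roughly, the timing loop now looks like this (a sketch that reuses model, optimizer, criterion, train_loader, device, and dist from the script above):

import time
import torch
import torch.distributed as dist

def timed():
    # CUDA kernels are launched asynchronously, so wait for the GPU to finish
    # all queued work before reading the wall clock.
    torch.cuda.synchronize()
    return time.time()

fwd_time = bwd_time = apply_time = 0
for x, y in train_loader:
    x, y = x.to(device), y.to(device)
    # forward
    t0 = timed()
    pred = model(x)
    loss = criterion(pred, y)
    t1 = timed()
    fwd_time += t1 - t0
    # backward
    optimizer.zero_grad()
    loss.backward()
    t2 = timed()
    bwd_time += t2 - t1
    # apply
    optimizer.step()
    t3 = timed()
    apply_time += t3 - t2

if dist.get_rank() == 0:
    print("fwd_time:", fwd_time / len(train_loader))
    print("bwd_time:", bwd_time / len(train_loader))
    print("apply_time:", apply_time / len(train_loader))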

fwd_time: 0.021761627197265626
bwd_time: 0.09805940628051758
apply_time: 0.017625408172607424

fwd_time: 0.021732749938964843
bwd_time: 0.09838109970092773
apply_time: 0.017648420333862304

fwd_time: 0.021735153198242187
bwd_time: 0.098248291015625
apply_time: 0.017606115341186522

fwd_time: 0.02169997215270996
bwd_time: 0.09806147575378418
apply_time: 0.017623739242553713

fwd_time: 0.021713523864746093
bwd_time: 0.09819583892822266
apply_time: 0.017622852325439455

fwd_time: 0.021727790832519533
bwd_time: 0.0983795166015625
apply_time: 0.017629175186157225

DDP’s epoch training time: 25 * (0.0217 + 0.098 + 0.018) ≈ 3.44.
Like Ray, DDP’s backward now takes most of the time, BUT DDP is still faster than Ray.

My current code tries to make the operator more similar to the DDP script:
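Roughly, the change is to construct the model, optimizer, and criterion inside setup() on each worker (mirroring what each DDP process does) rather than capturing the driver-side globals. A sketch, assuming batch_size is passed in through config:

class Cifar100TrainingOperator(TrainingOperator):
    def setup(self, config):
        import timm

        trans = config["trans"]
        train_dataset = torchvision.datasets.CIFAR100(
            root="./cifar100_data", train=True, download=True, transform=trans)
        val_dataset = torchvision.datasets.CIFAR100(
            root="./cifar100_data", train=False, download=True, transform=trans)
        train_loader = DataLoader(dataset=train_dataset, batch_size=config["batch_size"], shuffle=True)
        val_loader = DataLoader(dataset=val_dataset, batch_size=config["batch_size"])

        # Build the model, optimizer, and criterion on the worker itself,
        # the same way each DDP process builds its own copies.
        model = timm.create_model('resnet50', pretrained=False, num_classes=100)
        optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
        criterion = torch.nn.CrossEntropyLoss()

        self.model, self.optimizer, self.criterion = self.register(
            models=model, optimizers=optimizer, criterion=criterion)
        self.register_data(train_loader=train_loader, validation_loader=val_loader)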

the result is:

  num_samples    epoch    batch_count    train_loss    last_train_loss    mean_train_epoch_s    mean_fwd_s    mean_grad_s    mean_apply_s
-------------  -------  -------------  ------------  -----------------  --------------------  ------------  -------------  --------------
        50000        1             25       4.29725            4.15469               5.74925      0.017165      0.0912009       0.0180373
        50000        2             25       3.92961            3.73614               5.74785     0.0161231      0.0949066        0.017922
        50000        3             25       3.45716            3.28966               5.69306     0.0162024      0.0989576       0.0173396
        50000        4             25       3.09345            3.09619               5.60449     0.0153117      0.0983784       0.0168418
        50000        5             25       2.78035            2.69492               5.61458     0.0155574       0.098223       0.0172038

It seems like the performance (fwd, back, apply) is now the same as your DDP numbers.

Thanks! Do you mean that moving these lines inside the operator makes sense?

model = timm.create_model('resnet50', pretrained=False, num_classes=100).to(self.device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

Did you try moving the instantiation of the model, optimizer, and criterion outside of the operator? In my experiments, the time was almost the same whether the code above was inside or outside the operator.