[AMP] Mixed precision training is slower than default precision

Hi,
I am trying to use mixed precision training via Ray Train, but after enabling AMP the training time is much longer. However, when I train resnet50 on the same dataset without Ray, AMP does give a performance improvement. Could you give me some advice on why this happens? Below is my code:

import argparse
from typing import Dict
from ray.air import session

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.models import resnet50

import ray.train as train
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

import torchvision.transforms as transforms
import os
from filelock import FileLock
from torchvision.datasets import CIFAR10



def train_epoch(dataloader, model, loss_fn, optimizer, amp=False):
    size = len(dataloader.dataset) // session.get_world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        if amp:
            train.torch.backward(loss)
        else:
            loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def validate_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset) // session.get_world_size()
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n "
        f"Accuracy: {(100 * correct):>0.1f}%, "
        f"Avg loss: {test_loss:>8f} \n"
    )
    return test_loss


def train_func(config: Dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    amp = config["amp"]

    worker_batch_size = batch_size // session.get_world_size()

    # Create model.
    if amp:
        print('amp activated')
        train.torch.accelerate(amp=True)

    model = resnet50()
    model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    if amp:
        optimizer = train.torch.prepare_optimizer(optimizer)


    # Load in training and validation data.
    transform_train = transforms.Compose(
        [   
            transforms.Resize([224,224]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )  # meanstd transformation

    transform_test = transforms.Compose(
        [
            transforms.Resize([224,224]),
            transforms.ToTensor(),
        ]
    )

    data_dir = config.get("data_dir", os.path.expanduser("~/data"))
    os.makedirs(data_dir, exist_ok=True)
    with FileLock(os.path.join(data_dir, ".ray.lock")):
        train_dataset = CIFAR10(
            root=data_dir, train=True, download=False, transform=transform_train
        )
        validation_dataset = CIFAR10(
            root=data_dir, train=False, download=False, transform=transform_test
        )

    train_dataloader = DataLoader(train_dataset, batch_size=worker_batch_size, pin_memory=True)
    test_dataloader = DataLoader(validation_dataset, batch_size=worker_batch_size, pin_memory=True)

    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    for _ in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer, amp=amp)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        session.report(dict(loss=loss))


def train_cifar10(num_workers=2, use_gpu=False, enable_amp=False):
    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={"lr": 1e-3, "batch_size": 128 * num_workers, "epochs": 1, "amp": enable_amp},
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    result = trainer.fit()
    print(f"Last result: {result.metrics}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    parser.add_argument(
        "--data-dir",
        required=False,
        type=str,
        default="~/data",
        help="Root directory for storing downloaded dataset.",
    )
    parser.add_argument(
        "--amp",
        action="store_true",
        default=False,
        help="Enable automatic mixed precision training.",
    )

    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        # 2 workers + 1 for trainer.
        ray.init(num_cpus=3)
        train_cifar10()
    else:
        ray.init(address=args.address)
        train_cifar10(num_workers=args.num_workers, use_gpu=args.use_gpu, enable_amp=args.amp)

You can run it with python torch_cifar10_example.py --use-gpu
or python torch_cifar10_example.py --use-gpu --amp

What GPU type are you using, and how many GPUs?

I have 4 A100s and I have tried with 1, 2, and 4 GPUs, but all of them show performance degradation.

I will try to reproduce this and get back to you. We have a test in CI that checks whether AMP performance is better, so this is interesting.

I ran the script with an A10 and AMP provided faster runtime and lower memory usage. Gonna see if I can get an A100.

I ran the script with an A100 again, and I still get the same result. In the dashboard, I found that when I activate AMP, GPU utilization drops from 100% to nearly 45%.

What are the torch and ray versions you are using?

I can actually replicate this on an A100; you are correct (AMP takes more time than fp32). That being said, I do not think this is related to Ray Train or Ray in general. I believe this issue from PyTorch should shed some light on the situation - torch.cuda.amp cannot speed up on A100 · Issue #57806 · pytorch/pytorch · GitHub

I have tried disabling tf32 support in the train function, and that slowed down non-AMP training considerably. Here are the times I got:

  • fp32 with tf32 (default): 106s
  • AMP with tf32: 112s
  • fp32 without tf32: 121s

Code to disable tf32:

def train_func(config: Dict):
    import torch.backends.cuda

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
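
For reference, a quick way to confirm what each worker is actually running with is to print the flags at the start of the training function (a minimal check using the standard torch.backends attributes, not part of the timings above):

def train_func(config: Dict):
    import torch

    # Inspect whether TF32 kernels are currently allowed for matmuls and cuDNN ops.
    print("matmul allow_tf32:", torch.backends.cuda.matmul.allow_tf32)
    print("cudnn allow_tf32:", torch.backends.cudnn.allow_tf32)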

We may want to look into taking that into account in our AMP code, but in this case, the best way for you to move forward is to not use Ray Train's AMP and instead implement it yourself (or use a third-party library like Accelerate, for which we have added support in Ray nightly with ray.train.huggingface.accelerate.AccelerateTrainer).
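
For reference, a minimal sketch of what hand-rolled AMP could look like inside the training loop, using torch.cuda.amp directly instead of train.torch.accelerate/train.torch.backward (the function name train_epoch_amp is illustrative; treat this as a starting point under the same variable names as the script above, not a drop-in replacement):

from torch.cuda.amp import GradScaler, autocast

def train_epoch_amp(dataloader, model, loss_fn, optimizer, scaler):
    model.train()
    for X, y in dataloader:
        optimizer.zero_grad()
        # Run the forward pass and loss computation in mixed precision.
        with autocast():
            pred = model(X)
            loss = loss_fn(pred, y)
        # Scale the loss to avoid fp16 gradient underflow, then step and update the scale factor.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

# In train_func, create the scaler once and reuse it across epochs:
# scaler = GradScaler()
# train_epoch_amp(train_dataloader, model, loss_fn, optimizer, scaler)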

I am using Ray 2.1.0 and torch 2.0.0, and I get training times similar to yours.
Thanks for your help.