Hi,
I'm trying to use mixed precision training with Ray Train, but after enabling AMP the training time gets much longer instead of shorter. When I train ResNet-50 on the same dataset without Ray, AMP does give a clear speedup. Could you give me some advice on why this happens?
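For reference, the non-Ray baseline I compared against is just the standard torch.cuda.amp pattern, roughly like this minimal sketch (not my exact script; the dataloader and hyperparameters here are placeholders):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler
from torchvision.models import resnet50

def train_plain_amp(dataloader, epochs=1, lr=1e-3, device="cuda"):
    # Plain PyTorch AMP loop: autocast for the forward pass, GradScaler for the backward pass.
    model = resnet50().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    scaler = GradScaler()  # scales the loss to avoid fp16 gradient underflow
    model.train()
    for _ in range(epochs):
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            with autocast():  # run the forward pass in mixed precision
                loss = loss_fn(model(X), y)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

That version is clearly faster with AMP on my machine, while the Ray Train version is slower with AMP than without it. Below is my Ray Train code: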
import argparse
from typing import Dict
from ray.air import session
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.models import resnet50
import ray.train as train
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
import torchvision.transforms as transforms
import os
from filelock import FileLock
from torchvision.datasets import CIFAR10
def train_epoch(dataloader, model, loss_fn, optimizer, amp=False):
    size = len(dataloader.dataset) // session.get_world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        if amp:
            # Ray Train's backward, so gradient scaling is applied when accelerate(amp=True) is set
            train.torch.backward(loss)
        else:
            loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def validate_epoch(dataloader, model, loss_fn):
    size = len(dataloader.dataset) // session.get_world_size()
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n "
        f"Accuracy: {(100 * correct):>0.1f}%, "
        f"Avg loss: {test_loss:>8f} \n"
    )
    return test_loss
def train_func(config: Dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    amp = config["amp"]

    worker_batch_size = batch_size // session.get_world_size()

    # Create model.
    if amp:
        print("amp activated")
        # Enable Ray Train's AMP support before preparing the model.
        train.torch.accelerate(amp=True)
    model = resnet50()
    model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    if amp:
        optimizer = train.torch.prepare_optimizer(optimizer)

    # Load in training and validation data.
    transform_train = transforms.Compose(
        [
            transforms.Resize([224, 224]),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    transform_test = transforms.Compose(
        [
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
        ]
    )

    data_dir = config.get("data_dir", os.path.expanduser("~/data"))
    os.makedirs(data_dir, exist_ok=True)
    with FileLock(os.path.join(data_dir, ".ray.lock")):
        train_dataset = CIFAR10(
            root=data_dir, train=True, download=False, transform=transform_train
        )
        validation_dataset = CIFAR10(
            root=data_dir, train=False, download=False, transform=transform_test
        )

    train_dataloader = DataLoader(
        train_dataset, batch_size=worker_batch_size, pin_memory=True
    )
    test_dataloader = DataLoader(
        validation_dataset, batch_size=worker_batch_size, pin_memory=True
    )
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    for _ in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer, amp=amp)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        session.report(dict(loss=loss))
def train_cifar10(num_workers=2, use_gpu=False, enable_amp=False):
    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={
            "lr": 1e-3,
            "batch_size": 128 * num_workers,
            "epochs": 1,
            "amp": enable_amp,
        },
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    result = trainer.fit()
    print(f"Last result: {result.metrics}")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    parser.add_argument(
        "--data-dir",
        required=False,
        type=str,
        default="~/data",
        help="Root directory for storing downloaded dataset.",
    )
    parser.add_argument(
        "--amp",
        action="store_true",
        default=False,
        help="Enable automatic mixed precision training.",
    )
    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        # 2 workers + 1 for trainer.
        ray.init(num_cpus=3)
        train_cifar10()
    else:
        ray.init(address=args.address)
        train_cifar10(
            num_workers=args.num_workers, use_gpu=args.use_gpu, enable_amp=args.amp
        )
You can run it with:
python torch_cifar10_example.py --use-gpu
or, with AMP enabled:
python torch_cifar10_example.py --use-gpu --amp
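In case it matters, the "much longer" comparison is just wall-clock time around the epoch loop, measured roughly like this (a minimal sketch, not the exact code I used; it wraps the train_epoch() from the script above):

import time

def timed_epoch(dataloader, model, loss_fn, optimizer, amp=False):
    # Simple wall-clock timing around one training epoch.
    start = time.perf_counter()
    train_epoch(dataloader, model, loss_fn, optimizer, amp=amp)
    elapsed = time.perf_counter() - start
    print(f"epoch took {elapsed:.1f}s (amp={amp})")
    return elapsed

Measured this way, the Ray Train run with --amp takes noticeably longer per epoch than the run without it.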