Hi! I am trying to use mixed precision training in part of the workers, for example, using two workers for training, but worker0 train with amp and worker1 use the default fp32 precision. Here are my code modified from example in Ray repository.
# file name: torch_cifar_train_example2.py
import argparse
from typing import Dict
from ray.air import session
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.models import resnet50
import ray.train as train
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
import torchvision.transforms as transforms
from filelock import FileLock
from torchvision.datasets import CIFAR10
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
def train_epoch(dataloader, model, loss_fn, optimizer,amp = False):
size = len(dataloader.dataset) // session.get_world_size()
model.train()
for batch, (X, y) in enumerate(dataloader):
# print(f"inuput X: {X}") # input OK
# Compute prediction error
pred = model(X) # pred = non!!!!!!!!!!!!!!
check = int((pred != pred).sum())
if(check>0):
print("your data contains Nan")
else:
print("Your data does not contain Nan, it might be other problem")
# print(f"pred: {pred}")
# print(f"model: {model}")
loss = loss_fn(pred, y)
# Backpropagation
optimizer.zero_grad()
if(amp):
train.torch.backward(loss)
else:
loss.backward()
optimizer.step()
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
def validate_epoch(dataloader, model, loss_fn):
size = len(dataloader.dataset) // session.get_world_size()
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n "
f"Accuracy: {(100 * correct):>0.1f}%, "
f"Avg loss: {test_loss:>8f} \n"
)
return test_loss
def train_func(config: Dict):
# pytorch use tf32 in default when training on A100
import torch.backends.cuda
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
batch_size = config["batch_size"]
lr = config["lr"]
epochs = config["epochs"]
# amp = config["amp"]
amp_num = config["amp_num"]
worker_batch_size = batch_size // session.get_world_size()
amp = False
if(session.get_world_rank() == amp_num):
amp = True
if(amp):
print('amp activated')
train.torch.accelerate(amp=True)
model = resnet50()
model = train.torch.prepare_model(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
if(amp):
optimizer = train.torch.prepare_optimizer(optimizer)
# Load in training and validation data.
transform_train = transforms.Compose(
[
transforms.Resize(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
]
) # meanstd transformation
transform_test = transforms.Compose(
[
transforms.Resize(224),
transforms.ToTensor(),
]
)
data_dir = config.get("data_dir", os.path.expanduser("~/data"))
os.makedirs(data_dir, exist_ok=True)
with FileLock(os.path.join(data_dir, ".ray.lock")):
train_dataset = CIFAR10(
root=data_dir, train=True, download=False, transform=transform_train
)
validation_dataset = CIFAR10(
root=data_dir, train=False, download=False, transform=transform_test
)
train_dataloader = DataLoader(train_dataset, batch_size=worker_batch_size,pin_memory=True)
test_dataloader = DataLoader(validation_dataset, batch_size=worker_batch_size,pin_memory=True)
train_dataloader = train.torch.prepare_data_loader(train_dataloader)
test_dataloader = train.torch.prepare_data_loader(test_dataloader)
worker_batch_size = config["batch_size"] // session.get_world_size()
for _ in range(epochs):
train_epoch(train_dataloader, model, loss_fn, optimizer,amp=amp)
loss = validate_epoch(test_dataloader, model, loss_fn)
session.report(dict(loss=loss))
def train_cifar10(num_workers=2, use_gpu=False,enable_amp=False):
trainer = TorchTrainer(
train_loop_per_worker=train_func,
train_loop_config={"lr": 1e-3, "batch_size": 256*num_workers, "epochs": 2,"amp_num":0}, #first gpu use amp
scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
)
result = trainer.fit()
print(f"Last result: {result.metrics}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--address", required=False, type=str, help="the address to use for Ray"
)
parser.add_argument(
"--num-workers",
"-n",
type=int,
default=1,
help="Sets number of workers for training.",
)
parser.add_argument(
"--use-gpu", action="store_true", default=False, help="Enables GPU training"
)
parser.add_argument(
"--smoke-test",
action="store_true",
default=False,
help="Finish quickly for testing.",
)
parser.add_argument(
"--data-dir",
required=False,
type=str,
default="~/data",
help="Root directory for storing downloaded dataset.",
)
parser.add_argument(
"--amp",
action="store_true",
default=False,
help="Enable automatic mixed precision training.",
)
args, _ = parser.parse_known_args()
import ray
if args.smoke_test:
# 2 workers + 1 for trainer.
ray.init(num_cpus=3)
train_cifar10()
else:
ray.init(address=args.address)
train_cifar10(num_workers=args.num_workers, use_gpu=args.use_gpu, enable_amp=args.amp)
You could run it by python torch_cifar_train_example2.py --use-gpu -n 2
. However, during the model forward i.e. model(X) outputs nan in the worker without amp. In this example, I set GPU0 to use amp, and the GPU1’s model outputs nan.
What should I do to avoid nan in the model output?