Can I just use amp in part of the training actors?

Hi! I am trying to use mixed precision training in only part of the workers — for example, using two workers for training, where worker0 trains with AMP and worker1 uses the default fp32 precision. Here is my code, modified from an example in the Ray repository.

# file name: torch_cifar_train_example2.py
import argparse
from typing import Dict
from ray.air import session

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torchvision.models import resnet50

import ray.train as train
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig

import torchvision.transforms as transforms
from filelock import FileLock
from torchvision.datasets import CIFAR10
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"


def train_epoch(dataloader, model, loss_fn, optimizer, amp=False):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: yields (inputs, labels) batches; assumed to already be
            device-placed by ``train.torch.prepare_data_loader``.
        model: the (possibly DDP-wrapped) model to train.
        loss_fn: criterion mapping (pred, target) -> scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        amp: when True, run the backward pass through Ray's AMP-aware
            ``train.torch.backward`` (which scales the loss via GradScaler);
            otherwise use plain ``loss.backward()``.
    """
    # Per-worker share of the dataset (DistributedSampler splits it evenly).
    size = len(dataloader.dataset) // session.get_world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        # Debug aid for the NaN issue: torch.isnan is the idiomatic check
        # (NaN != NaN elementwise counting does the same thing, less clearly).
        if torch.isnan(pred).any():
            print("your data contains Nan")
        else:
            print("Your data does not contain Nan, it might be other problem")
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        if amp:
            # Scales the loss and calls backward on the scaled loss.
            train.torch.backward(loss)
        else:
            loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            # Use a distinct name so the `loss` tensor is not shadowed by
            # its own float value (the original rebound `loss` here).
            loss_val, current = loss.item(), batch * len(X)
            print(f"loss: {loss_val:>7f}  [{current:>5d}/{size:>5d}]")


def validate_epoch(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader`` and return the average loss.

    Also prints this worker's accuracy and mean loss. Accuracy is computed
    against the per-worker share of the dataset.
    """
    per_worker_size = len(dataloader.dataset) // session.get_world_size()
    batch_count = len(dataloader)
    model.eval()
    total_loss, hit_count = 0, 0
    with torch.no_grad():
        for features, labels in dataloader:
            outputs = model(features)
            total_loss += loss_fn(outputs, labels).item()
            hit_count += (outputs.argmax(1) == labels).type(torch.float).sum().item()
    avg_loss = total_loss / batch_count
    accuracy = hit_count / per_worker_size
    print(
        f"Test Error: \n "
        f"Accuracy: {(100 * accuracy):>0.1f}%, "
        f"Avg loss: {avg_loss:>8f} \n"
    )
    return avg_loss


def train_func(config: Dict):
    """Per-worker training loop executed by Ray's TorchTrainer.

    Expected config keys: "batch_size" (global), "lr", "epochs", and
    "amp_num" — the world rank that should train with AMP; all other ranks
    train in fp32. Optionally "data_dir" for the CIFAR-10 root.
    """
    # PyTorch enables TF32 by default on A100; disable it so the non-AMP
    # workers really compute in fp32.
    import torch.backends.cuda

    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False

    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]
    amp_num = config["amp_num"]

    # Each worker processes an equal slice of the global batch.
    worker_batch_size = batch_size // session.get_world_size()

    # Only the worker whose rank matches amp_num uses mixed precision.
    # NOTE(review): mixing AMP and fp32 workers under DDP is unsafe — when
    # the AMP worker produces NaN gradients (which its GradScaler would
    # normally skip), DDP still all-reduces them into the fp32 workers,
    # which then step on NaN gradients.
    amp = session.get_world_rank() == amp_num
    if amp:
        print('amp activated')
        # Must run before prepare_model so the model/optimizer get patched
        # for mixed precision.
        train.torch.accelerate(amp=True)

    model = resnet50()

    model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    if amp:
        # Wrap the optimizer so steps go through the GradScaler.
        optimizer = train.torch.prepare_optimizer(optimizer)

    # Load in training and validation data.
    transform_train = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )  # meanstd transformation

    transform_test = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.ToTensor(),
        ]
    )

    data_dir = config.get("data_dir", os.path.expanduser("~/data"))
    os.makedirs(data_dir, exist_ok=True)
    # FileLock keeps multiple workers on the same node from racing on the
    # dataset files. download=False assumes the data is already present.
    with FileLock(os.path.join(data_dir, ".ray.lock")):
        train_dataset = CIFAR10(
            root=data_dir, train=True, download=False, transform=transform_train
        )
        validation_dataset = CIFAR10(
            root=data_dir, train=False, download=False, transform=transform_test
        )

    train_dataloader = DataLoader(
        train_dataset, batch_size=worker_batch_size, pin_memory=True
    )
    test_dataloader = DataLoader(
        validation_dataset, batch_size=worker_batch_size, pin_memory=True
    )

    # Adds a DistributedSampler and moves batches to this worker's device.
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    for _ in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer, amp=amp)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        session.report(dict(loss=loss))


def train_cifar10(num_workers=2, use_gpu=False, enable_amp=False):
    """Launch distributed CIFAR-10 training with Ray TorchTrainer.

    Args:
        num_workers: number of Ray training workers.
        use_gpu: give each worker a GPU.
        enable_amp: when True, worker rank 0 trains with AMP. Previously
            this parameter was accepted but ignored and rank 0 always used
            AMP regardless of the flag.
    """
    # amp_num is the world rank that trains with AMP; -1 matches no rank,
    # so with the flag off every worker stays in fp32.
    amp_num = 0 if enable_amp else -1
    trainer = TorchTrainer(
        train_loop_per_worker=train_func,
        train_loop_config={
            "lr": 1e-3,
            "batch_size": 256 * num_workers,  # global batch, split per worker
            "epochs": 2,
            "amp_num": amp_num,
        },
        scaling_config=ScalingConfig(num_workers=num_workers, use_gpu=use_gpu),
    )
    result = trainer.fit()
    print(f"Last result: {result.metrics}")


if __name__ == "__main__":
    # CLI entry point: parse options, connect to (or start) Ray, and launch
    # the distributed training job.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=1,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )
    parser.add_argument(
        "--smoke-test",
        action="store_true",
        default=False,
        help="Finish quickly for testing.",
    )
    # NOTE(review): --data-dir is parsed but never forwarded to
    # train_cifar10 / train_loop_config, so train_func always falls back to
    # its default ~/data. TODO: wire it through.
    parser.add_argument(
        "--data-dir",
        required=False,
        type=str,
        default="~/data",
        help="Root directory for storing downloaded dataset.",
    )
    parser.add_argument(
        "--amp",
        action="store_true",
        default=False,
        help="Enable automatic mixed precision training.",
    )

    # parse_known_args ignores any extra flags Ray's launcher may add.
    args, _ = parser.parse_known_args()

    import ray

    if args.smoke_test:
        # 2 workers + 1 for trainer.
        ray.init(num_cpus=3)
        train_cifar10()
    else:
        # address=None starts a local Ray instance; otherwise connect to an
        # existing cluster.
        ray.init(address=args.address)
        train_cifar10(num_workers=args.num_workers, use_gpu=args.use_gpu, enable_amp=args.amp)

You can run it with python torch_cifar_train_example2.py --use-gpu -n 2. However, during the forward pass, i.e. model(X), the model outputs NaN in the worker that is not using AMP. In this example, I set GPU0 to use AMP, and GPU1's model outputs NaN.

What should I do to avoid nan in the model output?

Hey @HuangLianghong , as for Ray, we have no restriction on the behavior of each actor, and you can definitely use AMP in part of the actors. But for model training, I don’t think it’s safe to do so.

In AMP training, NaN gradients occasionally occur, and the "scaler.step(optimizer)" call of the AMP worker will safely skip that step. However, when gradients are synchronized between AMP workers and non-AMP workers in DDP, the globally reduced gradient may be NaN. The non-AMP DDP worker won't skip the optimizer step, and thus may corrupt its weights.

Thank you @yunxuanx. I have tried printing the output after pred = model(X): in the first batch there is no NaN, but after that, the output becomes NaN. Is this due to the gradient synchronization?

It confused me, because I don’t know which line of code synchronizes gradients and does synchronization happen after each batch or after each iteration?

Hi @HuangLianghong, to check the gradients, I think you can print the model parameters' .grad after loss.backward(). Theoretically, gradient synchronization has completed by the time optimizer.step() runs, i.e. after loss.backward() returns.