How to check that training and validation are distributed properly on the Ray cluster

Hey, while training and validating, the progress output for each worker is repeated, and I'm not sure how it is working. It looks like the workers are executing the training separately, and the loss and accuracy differ across workers. Could you explain why it behaves like this?


import argparse
from typing import Dict

import torch
import ray.train as train
from ray.train.trainer import Trainer
from ray.train.callbacks import JsonLoggerCallback
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="~/data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="~/data",
    train=False,
    download=True,
    transform=ToTensor(),
)


# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def train_epoch(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset) // train.world_size()
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")


def validate_epoch(dataloader, model, loss_fn):
    print("validation starts")
    size = len(dataloader.dataset) // train.world_size()
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n "
        f"Accuracy: {(100 * correct):>0.1f}%, "
        f"Avg loss: {test_loss:>8f} \n"
    )
    print("validation ends")
    return test_loss


def train_func(config: Dict):
    batch_size = config["batch_size"]
    lr = config["lr"]
    epochs = config["epochs"]

    worker_batch_size = batch_size // train.world_size()

    # Create data loaders.
    train_dataloader = DataLoader(training_data, batch_size=worker_batch_size)
    test_dataloader = DataLoader(test_data, batch_size=worker_batch_size)

    # train_dataloader = train.torch.prepare_data_loader(train_dataloader)
    # test_dataloader = train.torch.prepare_data_loader(test_dataloader)

    # Create model.
    model = NeuralNetwork()
    # model = train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    loss_results = []

    for _ in range(epochs):
        train_epoch(train_dataloader, model, loss_fn, optimizer)
        loss = validate_epoch(test_dataloader, model, loss_fn)
        train.report(loss=loss)
        loss_results.append(loss)
    print(validate_epoch(test_dataloader, model, loss_fn))
    return loss_results


def train_fashion_mnist(num_workers=2, use_gpu=False):
    trainer = Trainer(backend="torch", num_workers=num_workers, use_gpu=use_gpu)
    trainer.start()
    result = trainer.run(
        train_func=train_func,
        config={"lr": 1e-3, "batch_size": 64, "epochs": 4},
        callbacks=[JsonLoggerCallback()],
    )
    trainer.shutdown()
    print(f"Loss results: {result}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--address", required=False, type=str, help="the address to use for Ray"
    )
    parser.add_argument(
        "--num-workers",
        "-n",
        type=int,
        default=2,
        help="Sets number of workers for training.",
    )
    parser.add_argument(
        "--use-gpu", action="store_true", default=False, help="Enables GPU training"
    )

    args, _ = parser.parse_known_args()

    import ray

    ray.init(address=args.address)
    train_fashion_mnist(num_workers=args.num_workers, use_gpu=args.use_gpu)

Hey @siva14guru, this is expected and is how data-parallel training works.

Each worker holds a replica of the model. Your dataset is sharded across all the workers, and each worker trains on its own shard of data. After every batch, the gradients across all the model replicas are synchronized.

So the work is not being duplicated; rather, each worker is training on different data (which is why each worker reports a different loss and accuracy). This lets you train on more data in less time, since the workers run in parallel.
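
Under the hood, train.torch.prepare_data_loader is what sets up the sharding: it wraps your DataLoader with a DistributedSampler so each worker only iterates over its own slice of the dataset. Roughly (a simplified sketch, not Ray's exact implementation):

from torch.utils.data import DataLoader, DistributedSampler

# Each worker draws a disjoint 1/world_size shard of the dataset.
sampler = DistributedSampler(
    training_data,
    num_replicas=train.world_size(),  # total number of workers
    rank=train.world_rank(),          # this worker's index
)
train_dataloader = DataLoader(
    training_data, batch_size=worker_batch_size, sampler=sampler
)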

For validation, if your dataset is small, you can also run it only on worker 0 by moving your validation code under an if train.world_rank() == 0: conditional, as in the sketch below.
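
A minimal sketch of the epoch loop from your train_func, with validation restricted to worker 0 (I believe train.report should still be called on every worker so they stay in lockstep):

for _ in range(epochs):
    train_epoch(train_dataloader, model, loss_fn, optimizer)
    loss = None
    # Only the first worker runs validation; with prepare_model the
    # replicas are kept in sync, so one set of metrics is enough.
    if train.world_rank() == 0:
        loss = validate_epoch(test_dataloader, model, loss_fn)
    # Every worker still reports, so the workers stay in step.
    train.report(loss=loss)
    loss_results.append(loss)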

In the code snippet provided, I also see that model = train.torch.prepare_model(model) (and the prepare_data_loader calls) are commented out. In this case, each worker is actually training its own model, starting from its own randomly seeded parameters, on the entire dataset, which would explain the divergence!
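
Concretely, uncommenting those lines inside train_func should give you true data-parallel training (a sketch against the same Ray Train API used in your snippet):

# Shard each data loader across the workers and move batches to the
# right device automatically.
train_dataloader = train.torch.prepare_data_loader(train_dataloader)
test_dataloader = train.torch.prepare_data_loader(test_dataloader)

# Wrap the model in DistributedDataParallel so gradients are
# synchronized across all replicas after every backward pass.
model = train.torch.prepare_model(model)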
