Ray Train with Ray Datasets (image data) too slow

Hello! I am trying to use Ray Train and Ray Datasets with the ImageNet dataset. I wish to apply the learnings here to large-scale production use cases (ours is a computer vision team).

I converted the ImageNet 2012 dataset to Parquet (binary images and numeric labels) using Ray (> 1.2M images in training; a rough sketch of this conversion is shown right after this list). My findings were:
a) When I use this on an 8-GPU p3.16xlarge machine, downloading the whole dataset first, it takes around 3 hours for one epoch to complete: 1.5 hours spent in some map_block_split code and 1.5 hours in actual training.
b) A plain PyTorch DistributedDataParallel run working directly on the downloaded images takes around 17 minutes on the same 8-GPU p3.16xlarge.
c) Suspecting disk I/O to be a contributing factor, I moved to a larger g5 machine (8 GPUs) with significantly faster NVMe instance-store storage. I was able to complete one epoch in 28 minutes total, with 15 minutes spent in training.
d) Suspecting ray.get to be a contributing factor, I moved the dataset loading into the training function using a PyTorch dataset. I was able to complete the overall run on the p3.16xlarge in under 20 minutes and on the g5 in around 5 minutes.
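
For reference, the Parquet conversion was done roughly along the following lines (a minimal sketch, not the exact script; the output path is an assumption, and the string category names still need to be mapped to numeric label ids):

import ray
from ray.data.datasource.partitioning import Partitioning, PartitionStyle

train_root = "/home/ec2-user/RAY_GIT/Imagenet/imagenet/train"
partitioning = Partitioning(PartitionStyle.DIRECTORY, field_names=["category"], base_dir=train_root)
# Read the raw image bytes together with their directory-derived category column,
# then persist the dataset as Parquet.
ds = ray.data.read_binary_files(paths=train_root, partitioning=partitioning)
ds.write_parquet("/home/ec2-user/RAY_GIT/Imagenet/imagenet_parquet/train")  # output path is hypothetical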

Is there a way to speed up distributed training using Ray Datasets? It would be great if you could point out any inefficiencies in my code. On a separate note, it would be great if the Ray team provided a performant distributed ImageNet training example on GPUs.

This is my training code with Ray Datasets:



# NOTE read this: https://docs.ray.io/en/latest/train/dl_guide.html#porting-code-to-ray-train

import numpy as np

import torch
from torchvision import transforms
# from ray.data.preprocessors import TorchVisionPreprocessor, Chain, LabelEncoder
from ray.data.preprocessors import Chain, LabelEncoder
from ray.data.datasource.partitioning import PartitionStyle
import ray

from ray import train
from ray.air import session, RunConfig, Checkpoint, CheckpointConfig
from ray.train.torch import TorchCheckpoint
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig
from ray.air.integrations.mlflow import MLflowLoggerCallback
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from ray.tune.logger import TBXLoggerCallback
import torchvision.models as models
import torchmetrics
import time
import pandas as pd
from io import BytesIO
from PIL import Image
from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.data import ActorPoolStrategy
import sys
from ray.data.dataset import Dataset
from typing import Tuple, List, Dict
import os
from ray.data.preprocessor import Preprocessor
from pandas import DataFrame
import json


# NOTE should read files in parallel after figuring out paths
# read parallelism criteria: https://docs.ray.io/en/latest/data/creating-datasets.html#read-parallelism
# dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
from ray.data.datasource.partitioning import Partitioning



# From https://pytorch.org/vision/main/_modules/torchvision/datasets/folder.html#ImageFolder
# This code is used to create the synset.json
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Finds the class folders in a dataset.

    See :class:`DatasetFolder` for details.
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
    if not classes:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
    return classes, class_to_idx
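
# Illustrative sketch (not part of the training script): synset.json can be generated once
# up front by running find_classes against the ImageNet train root and dumping the result, e.g.
#
#   classes, class_to_idx = find_classes("/home/ec2-user/RAY_GIT/Imagenet/imagenet/train")
#   with open("./synset.json", "w") as f_out:
#       json.dump({"classes": classes, "class_to_idx": class_to_idx}, f_out)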



f_synset = open("./synset.json", "r")
synset_dict = json.load(f_synset)
classes = synset_dict["classes"]
class_to_idx = synset_dict["class_to_idx"]
f_synset.close()


train_root = "/home/ec2-user/RAY_GIT/Imagenet/imagenet/train"
#train_root = "s3://a9vs-photon-model-training-exp/imagenet/imagenet/train"
train_partitioning = Partitioning(PartitionStyle.DIRECTORY, field_names=["category"], base_dir=train_root)
train_dataset = ray.data.read_binary_files(paths=train_root, partitioning=train_partitioning, parallelism=20000)
# train_dataset_pipe = ray.data.read_binary_files(paths=train_root, partitioning=train_partitioning).window(bytes_per_window=40e9)
print(train_dataset)
# print(train_dataset.schema())
total_train_dataset_len = train_dataset.count()
print(train_dataset.count())
# print(train_dataset.default_batch_format())
# print(train_dataset.fully_executed().size_bytes())

val_root = "/home/ec2-user/RAY_GIT/Imagenet/imagenet/val"
#val_root = "s3://a9vs-photon-model-training-exp/imagenet/imagenet/val"
val_partitioning = Partitioning(PartitionStyle.DIRECTORY, field_names=["category"], base_dir=val_root)
val_dataset = ray.data.read_binary_files(paths=val_root, partitioning=val_partitioning)
print(val_dataset)



# https://docs.ray.io/en/latest/ray-air/preprocessors.html#implementing-custom-preprocessors
# class CustomPreprocessor(BatchMapper):
class CustomPreprocessor(Preprocessor):
    _is_fittable = False

    def __init__(
            self,
            classes,
            class_to_idx):
        self.classes = classes
        self.class_to_idx = class_to_idx

    """
    def _fit(self, dataset: Dataset) -> Preprocessor:
        self.stats_ = dataset.aggregate(Max("value"))
    """

    def _transform_pandas(self, item_list: DataFrame) -> DataFrame:
        # we got class 'pandas.core.frame.DataFrame'
        # print(type(item_list))
        # print(item_list)
        byte_items = item_list['bytes']
        categories = item_list['category']
        # print(len(byte_items), len(categories))
        preprocess = transforms.Compose([
            # transforms.Resize(256),
            # transforms.CenterCrop(224),
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        images = [Image.open(BytesIO(byte_item)).convert('RGB') for byte_item in byte_items]
        images = [preprocess(image) for image in images]
        images = [np.array(image) for image in images]
        categories = [self.class_to_idx[category] for category in categories]

        return pd.DataFrame({"image": TensorArray(images), "category": categories})


custom_preprocessor = CustomPreprocessor(classes, class_to_idx)



# There needs to be a validate function too


# def train_epoch(train_dataset_shard, model, criterion, optimizer, **kwargs):
def train_epoch(epoch, train_dataset_shard, model, criterion, optimizer, **kwargs):
    """
    OBSOLETE
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, top1, top5],
        prefix="Epoch: [{}]".format(epoch))
    """
    rank = session.get_world_rank()
    batch_size = kwargs['batch_size']
    mape = kwargs['mape']
    mean_validation_loss = kwargs['mean_validation_loss']
    acc1 = kwargs['acc1']
    # acc5 = kwargs['acc5']
    total_train_dataset_len = kwargs['total_train_dataset_len']

    # switch to train mode
    model.train()

    end = time.time()

    # NOTE: when you load from a dataloader, use the following
    # size = len(dataloader.dataset) // session.get_world_size()
    # The following does not work:  ValueError("Cannot count a pipeline of infinite length.")
    # ds_shard_size = train_dataset_shard.count()
    ds_shard_size = total_train_dataset_len // session.get_world_size()

    train_dataset_batches = train_dataset_shard.iter_torch_batches(
        device=train.torch.get_device(),
        # Works for ray3.0 only i guess, check if this is needed as per https://github.com/ray-project/ray/pull/31692/files
        prefetch_blocks=3,
        batch_size=batch_size,  # default is 256, use None to set entire block as batch
        local_shuffle_buffer_size=batch_size * 2  # should be greater than or equal to the batch size
    )

    start = time.perf_counter()
    batches_read, bytes_read = 0, 0
    batch_start = time.perf_counter()
    batch_delays = []
    batch_sizes = []
    # print("train_dataset_shard_batches: ", train_dataset_shard_batches)
    print("train_dataset_batches: ", train_dataset_batches)
    for i, batch in enumerate(train_dataset_batches):
        batch_delay = time.perf_counter() - batch_start
        batch_delays.append(batch_delay)
        batches_read += 1


        # get the images and classes
        images, target = batch["image"], batch["category"]
        bytes_read += (images.size(0) * (224 * 224 * 3 + 64))

        batch_sizes.append(images.size(0))

        # TODO
        # measure data loading time
        # data_time.update(time.time() - end)

        """
        OBSOLETE
        # move data to the same device as model
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        """

        # compute output
        output = model(images)
        loss = criterion(output, target)

        batch_time = time.time() - end
        end = time.time()
        if i and i % 20 == 0:
            loss_item, current = loss.item(), i * images.size(0)
            print(f"loss: {loss_item:>7f}  [{current:>5d}/{ds_shard_size:>5d}], batch_time: {batch_time}")
            # Debugging for dataset and datasetpipeline https://docs.ray.io/en/latest/data/performance-tips.html#debugging-statistics
            print(train_dataset_shard.stats())


        # compute gradient and do SGD step
        optimizer.zero_grad()
        # loss.backward()
        # AMP
        train.torch.backward(loss)
        optimizer.step()


        batch_start = time.perf_counter()

    delta = time.perf_counter() - start
    print("Epoch: {}, Training Stats for rank: ".format(epoch), rank)
    total_time_reading = np.sum(batch_delays)
    print("Time to read all data in epoch", total_time_reading, "seconds")
    print(
        "P50/P95/Max batch delay (s)",
        np.quantile(batch_delays, 0.5),
        np.quantile(batch_delays, 0.95),
        np.max(batch_delays),
    )
    print("Num batches read", batches_read)
    print("Average number of rows in each batch: ", np.mean(batch_sizes))
    print("Num bytes read", round(bytes_read / (1024 * 1024), 2), "MiB")
    print(
        "Mean Dataset throughput", round(bytes_read / (1024 * 1024) / total_time_reading, 2), "MiB/s"
    )


def validate_epoch(epoch, validation_dataset_shard, model, criterion, **kwargs):
    batch_size = kwargs['batch_size']
    mape = kwargs['mape']
    mean_validation_loss = kwargs['mean_validation_loss']
    acc1 = kwargs['acc1']

    # acc5 = kwargs['acc5']

    def run_validate(validation_dataset_shard, base_progress=0):
        rank = session.get_world_rank()
        with torch.no_grad():
            end = time.time()

            validation_dataset_batches = validation_dataset_shard.iter_torch_batches(
                prefetch_blocks=10,
                batch_size=batch_size,  # default is 256, use None to set entire block as batch
                # drop_last=drop_last,
                local_shuffle_buffer_size=batch_size * 2,  # should be greater than or equal to the batch size
                device=train.torch.get_device(),
                # Works for ray3.0 only i guess, check if this is needed as per https://github.com/ray-project/ray/pull/31692/files
                # local_shuffle_seed=local_shuffle_seed,
            )
            running_loss = 0.
            start = time.perf_counter()
            batches_read, bytes_read = 0, 0
            batch_start = time.perf_counter()
            batch_delays = []
            batch_sizes = []
            for i, batch in enumerate(validation_dataset_batches):
                batch_delay = time.perf_counter() - batch_start
                batch_delays.append(batch_delay)
                batches_read += 1
                # get the images and classes
                images, target = batch["image"], batch["category"]
                i = base_progress + i
                bytes_read += (images.size(0) * (224 * 224 * 3 + 64))
                batch_sizes.append(images.size(0))


                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                # save loss in aggregator
                running_loss += loss.item()
                mean_validation_loss(loss)
                mape(torch.argmax(output, dim=1), target)
                acc1(torch.argmax(output, dim=1), target)
                # acc5(torch.argmax(output, dim=1), target)



                # measure elapsed time
                batch_time = time.time() - end
                end = time.time()


                batch_start = time.perf_counter()

            print("Epoch: {}, Validation Stats for rank: ".format(epoch), rank)
            total_time_reading = np.sum(batch_delays)
            print("Time to read all data in epoch", total_time_reading, "seconds")
            print(
                "P50/P95/Max batch delay (s)",
                np.quantile(batch_delays, 0.5),
                np.quantile(batch_delays, 0.95),
                np.max(batch_delays),
            )
            print("Num batches read", batches_read)
            print("Average number of rows in each batch: ", np.mean(batch_sizes))
            print("Num bytes read", round(bytes_read / (1024 * 1024), 2), "MiB")
            print(
                "Mean Dataset throughput", round(bytes_read / (1024 * 1024) / total_time_reading, 2), "MiB/s"
            )

            valid_loss_collected = running_loss / batches_read  # average loss over the batches read
            return valid_loss_collected



    # switch to evaluate mode
    model.eval()

    validation_loss_collected = run_validate(validation_dataset_shard)
    # collect all metrics
    # use .item() to obtain a value that can be reported
    mape_collected = mape.compute().item()
    mean_validation_loss_collected = mean_validation_loss.compute().item()
    acc1_collected = acc1.compute().item()
    # acc5_collected = acc5.compute().item()
    # TODO
    acc5_collected = None

    return mape_collected, validation_loss_collected, mean_validation_loss_collected, acc1_collected, acc5_collected



# We import Ray Train and Ray AIR Session:
# We use a config dict to configure some hyperparameters (this is not strictly needed but good practice, especially if you want to do hyperparameter tuning later):
# https://docs.ray.io/en/latest/ray-air/package-ref.html#ray.train.data_parallel_trainer.DataParallelTrainer
# Useful example code:
# 1) https://docs.ray.io/en/latest/ray-air/examples/torch_image_example.html
# 2) https://docs.ray.io/en/latest/train/examples/pytorch/torch_fashion_mnist_example.html
# 3) https://docs.ray.io/en/latest/ray-air/check-ingest.html#disabling-preprocessor-transforms
# 4) https://docs.ray.io/en/latest/train/dl_guide.html#distributed-data-ingest-with-ray-datasets
# 5) https://docs.ray.io/en/latest/ray-air/examples/convert_existing_pytorch_code_to_ray_air.html
def train_loop_per_worker(config: dict):
    # AMP
    train.torch.accelerate(amp=True)
    device = train.torch.get_device()
    rank = session.get_world_rank()
    # Reproducibility as per https://docs.ray.io/en/latest/train/dl_guide.html#reproducibility
    # train.torch.enable_reproducibility()

    # Parameters
    batch_size = config["batch_size"]
    lr = config["lr"]
    momentum = config["momentum"]
    weight_decay = config["weight_decay"]
    epochs = config["num_epochs"]
    total_train_dataset_len = config["total_train_dataset_len"]
    # We dynamically adjust the worker batch size according to the number of workers:
    batch_size_per_worker = batch_size // session.get_world_size()

    print("=> creating model '{}'".format("resnet18"))
    model = models.__dict__["resnet18"]()
    #  train.torch.prepare_model() also automatically takes care of setting up devices (e.g. GPU training) - so we can get rid of those lines in our current code!
    model = train.torch.prepare_model(model)
    # model.cuda(args.gpu) <== not needed

    # define loss function (criterion), optimizer, and learning rate scheduler
    # TODO hopefully ray train takes care of setting up the device properly
    criterion = nn.CrossEntropyLoss()
    # criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    # AMP
    optimizer = train.torch.prepare_optimizer(optimizer)

    # Get worker shards over data
    # Datasets can be accessed in your train_func via ``get_dataset_shard``.
    # This has become a pipeline due to streaming in the dataset config
    train_dataset_shard = session.get_dataset_shard("train")
    # This is a dataset
    validation_dataset_shard = session.get_dataset_shard("validation")
    print("train_dataset_shard: ", train_dataset_shard)
    print("validation_dataset_shard: ", validation_dataset_shard)

    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    # NOTE: metrics
    # measure accuracy and record loss
    mape = torchmetrics.MeanAbsolutePercentageError().to(device)
    # for averaging loss
    mean_validation_loss = torchmetrics.MeanMetric().to(device)
    acc1 = torchmetrics.Accuracy(task="multiclass", num_classes=1000, top_k=1).to(device)
    # TODO
    acc5 = None
    # acc5 = torchmetrics.Accuracy(task="multiclass", num_classes=1000, top_k=5).to(device)



    start_epoch = 0

    checkpoint = session.get_checkpoint()
    if checkpoint:
        # assume that we have run the session.report() example
        # and successfully saved some model weights
        checkpoint_dict = checkpoint.to_dict()
        model.load_state_dict(checkpoint_dict.get("model_weights"))
        start_epoch = checkpoint_dict.get("epoch", -1) + 1

    epochs_read = 0
    # We will iterate over epoch-delimited pipeline segments; streaming will create windows of blocks whose order changes within a worker. Data is nicely
    # split between workers for one epoch, though all workers see all data over multiple epochs. This seems to be in contrast to PyTorch DataParallel with shuffle=False, where a worker sees the same data every epoch.
    # reference : https://github.com/ray-project/ray/issues/18287
    # for epoch, train_dataset_shard_batches in enumerate(train_dataset_shard.iter_epochs()):
    total_time_in_training = 0.
    for epoch in range(start_epoch, epochs):
        print(" Epoch ", epoch)
        epochs_read += 1



        # train for one epoch
        # train_epoch(train_dataset_shard, model, criterion, optimizer, batch_size=batch_size, mape=mape,
        epoch_start_time = time.time()
        train_epoch(epoch, train_dataset_shard, model, criterion, optimizer, batch_size=batch_size_per_worker, mape=mape,
                    mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5,
                    total_train_dataset_len=total_train_dataset_len)
        training_epoch_time = time.time() - epoch_start_time
        total_time_in_training += training_epoch_time
        print("Training time in epoch for rank {} = {}".format(rank, training_epoch_time))

        validation_start_time = time.time()
        mape_collected, validation_loss_collected, mean_validation_loss_collected, acc1_collected, acc5_collected = validate_epoch(
            epoch, validation_dataset_shard, model, criterion, batch_size=batch_size, mape=mape,
            mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5)
        print("Validation time in epoch for rank {} = {}".format(rank, time.time() - validation_start_time))

        # checkpoint for fault tolerance
        state_dict = model.state_dict()
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        checkpoint = Checkpoint.from_dict(
            # TODO, where is loss coming from, perhaps from validation
            dict(epoch=epoch, model_weights=state_dict, mean_validation_loss_collected=mean_validation_loss_collected)
        )
        # Report session and checkpoint
        session.report(
            {
                "mape_collected": mape_collected,
                "validation_loss_collected": validation_loss_collected,
                "mean_validation_loss_collected": mean_validation_loss_collected,
                "acc1_collected": acc1_collected,
                # "acc5_collected": acc5_collected
            },
            checkpoint=checkpoint
        )

        # reset for next epoch
        mape.reset()
        mean_validation_loss.reset()
        acc1.reset()
        # acc5.reset()



        scheduler.step()


    print("******************* Total time in training: ", total_time_in_training)
    # NOTE https://docs.ray.io/en/latest/train/dl_guide.html#configuring-checkpoints


# Keep the 2 checkpoints with the smallest "mean_validation_loss_collected" value.
checkpoint_config = CheckpointConfig(
    num_to_keep=2, checkpoint_score_attribute="mean_validation_loss_collected", checkpoint_score_order="min"
)

# NOTE: we just use this preprocessor to convert categories to integers
dataset_preprocessor = LabelEncoder(label_column="category")
# NOTE refer to https://docs.ray.io/en/latest/train/dl_guide.html#configuring-training
# https://docs.ray.io/en/latest/train/api.html#pytorch
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"num_epochs": 1, "batch_size": 256, "lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4,
                       "total_train_dataset_len": total_train_dataset_len},
    # https://docs.ray.io/en/latest/train/dl_guide.html#example-logging-to-mlflow-and-tensorboard

    datasets={"train": train_dataset, "validation": val_dataset},
    # To scale your training script, create a Ray Cluster and increase the number of workers. If your cluster contains GPUs, add "use_gpu": True to your scaling config.
    # scaling_config=ScalingConfig(num_workers=8, use_gpu=True, trainer_resources={"CPU": 20}, resources_per_worker={"CPU": 2, "GPU": 1}),
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True, resources_per_worker={"CPU": 2, "GPU": 1}),
    run_config=RunConfig(
        checkpoint_config=checkpoint_config,
        callbacks=[
            MLflowLoggerCallback(experiment_name="train_experiment"),
            TBXLoggerCallback(),
        ],
        # https://docs.ray.io/en/latest/ray-air/package-ref.html#ray.air.config.RunConfig
        verbose=3,
        log_to_file=True
    ),
    dataset_config={
        "train": DatasetConfig(fit=True, split=True, global_shuffle=False),
    },

    # preprocessor = dataset_preprocessor,
    # preprocessor = preprocessor,
    preprocessor=custom_preprocessor
    # if you wish to resume from a previously saved checkpoint
    # resume_from_checkpoint=result.checkpoint,
)

# NOTE: interpret training results https://docs.ray.io/en/latest/ray-air/trainer.html#how-to-interpret-training-results
result = trainer.fit()
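
For reference, after training I inspect the returned Result object roughly like this (a minimal sketch; I only look at the reported metrics and the checkpoint):

# Sketch: the metrics dict reported via session.report() and the associated checkpoint
print(result.metrics)
print(result.checkpoint)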

This is my code with the PyTorch DataLoader:


import numpy as np

import torch
from torchvision import transforms
# from ray.data.preprocessors import TorchVisionPreprocessor, Chain, LabelEncoder
from ray.data.preprocessors import Chain, LabelEncoder
from ray.data.datasource.partitioning import PartitionStyle
import ray

from ray import train
from ray.air import session, RunConfig, Checkpoint, CheckpointConfig
from ray.train.torch import TorchCheckpoint
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig
from ray.air.integrations.mlflow import MLflowLoggerCallback
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from ray.tune.logger import TBXLoggerCallback
import torchvision.models as models
import torchmetrics
import time
import pandas as pd
from io import BytesIO
from PIL import Image
from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.data import ActorPoolStrategy
import sys
from ray.data.dataset import Dataset
from typing import Tuple, List, Dict
import os
from ray.data.preprocessor import Preprocessor
from pandas import DataFrame
import json
import torchvision.datasets as datasets

from ray.data.datasource.partitioning import Partitioning




def train_epoch(epoch, train_dataloader, model, criterion, optimizer, **kwargs):
    size = len(train_dataloader.dataset) // session.get_world_size()  # Divide by world size
    rank = session.get_world_rank()
    batch_size = kwargs['batch_size']
    mape = kwargs['mape']
    mean_validation_loss = kwargs['mean_validation_loss']
    acc1 = kwargs['acc1']
    # acc5 = kwargs['acc5']
    #total_train_dataset_len = kwargs['total_train_dataset_len']

    # switch to train mode
    model.train()

    end = time.time()



    start = time.perf_counter()
    batches_read, bytes_read = 0, 0
    batch_start = time.perf_counter()
    batch_delays = []
    batch_sizes = []

    for i, (images, target) in enumerate(train_dataloader):
        batch_delay = time.perf_counter() - batch_start
        batch_delays.append(batch_delay)
        batches_read += 1

        # get the images and classes
        #images, target = batch["image"], batch["category"]
        bytes_read += (images.size(0) * (224 * 224 * 3 + 64))

        batch_sizes.append(images.size(0))


        # compute output
        output = model(images)
        loss = criterion(output, target)

        batch_time = time.time() - end
        end = time.time()
        if i and i % 20 == 0:
            loss_item, current = loss.item(), i * images.size(0)
            print(f"loss: {loss_item:>7f}  [{current:>5d}/{size:>5d}], batch_time: {batch_time}")
            # Debugging for dataset and datasetpipeline https://docs.ray.io/en/latest/data/performance-tips.html#debugging-statistics
            #print(train_dataset_shard.stats())


        # compute gradient and do SGD step
        optimizer.zero_grad()
        # loss.backward()
        # AMP
        train.torch.backward(loss)
        optimizer.step()

        batch_start = time.perf_counter()

    delta = time.perf_counter() - start
    print("Epoch: {}, Training Stats for rank: ".format(epoch), rank)
    total_time_reading = np.sum(batch_delays)
    print("Time to read all data in epoch", total_time_reading, "seconds")
    print(
        "P50/P95/Max batch delay (s)",
        np.quantile(batch_delays, 0.5),
        np.quantile(batch_delays, 0.95),
        np.max(batch_delays),
    )
    print("Num batches read", batches_read)
    print("Average number of rows in each batch: ", np.mean(batch_sizes))
    print("Num bytes read", round(bytes_read / (1024 * 1024), 2), "MiB")
    print(
        "Mean Dataset throughput", round(bytes_read / (1024 * 1024) / total_time_reading, 2), "MiB/s"
    )


#def validate_epoch(epoch, validation_dataset_shard, model, criterion, **kwargs):
def validate_epoch(epoch, val_dataloader, model, criterion, **kwargs):
    batch_size = kwargs['batch_size']
    mape = kwargs['mape']
    mean_validation_loss = kwargs['mean_validation_loss']
    acc1 = kwargs['acc1']

    # acc5 = kwargs['acc5']

    #def run_validate(validation_dataset_shard, base_progress=0):
    def run_validate(val_dataloader, base_progress=0):
        rank = session.get_world_rank()
        with torch.no_grad():
            end = time.time()

            running_loss = 0.
            start = time.perf_counter()
            batches_read, bytes_read = 0, 0
            batch_start = time.perf_counter()
            batch_delays = []
            batch_sizes = []

            #for i, batch in enumerate(validation_dataset_batches):
            for i, (images, target) in enumerate(val_dataloader):
                batch_delay = time.perf_counter() - batch_start
                batch_delays.append(batch_delay)
                batches_read += 1
                # get the images and classes
                #images, target = batch["image"], batch["category"]
                i = base_progress + i
                bytes_read += (images.size(0) * (224 * 224 * 3 + 64))
                batch_sizes.append(images.size(0))

                # compute output
                output = model(images)
                loss = criterion(output, target)

                # measure accuracy and record loss
                # save loss in aggregator
                running_loss += loss.item()
                mean_validation_loss(loss)
                mape(torch.argmax(output, dim=1), target)
                acc1(torch.argmax(output, dim=1), target)
                # acc5(torch.argmax(output, dim=1), target)


                # measure elapsed time
                batch_time = time.time() - end
                end = time.time()

                batch_start = time.perf_counter()

            print("Epoch: {}, Validation Stats for rank: ".format(epoch), rank)
            total_time_reading = np.sum(batch_delays)
            print("Time to read all data in epoch", total_time_reading, "seconds")
            print(
                "P50/P95/Max batch delay (s)",
                np.quantile(batch_delays, 0.5),
                np.quantile(batch_delays, 0.95),
                np.max(batch_delays),
            )
            print("Num batches read", batches_read)
            print("Average number of rows in each batch: ", np.mean(batch_sizes))
            print("Num bytes read", round(bytes_read / (1024 * 1024), 2), "MiB")
            print(
                "Mean Dataset throughput", round(bytes_read / (1024 * 1024) / total_time_reading, 2), "MiB/s"
            )

            valid_loss_collected = running_loss / batches_read  # average loss over the batches read
            return valid_loss_collected


    # switch to evaluate mode
    model.eval()

    #validation_loss_collected = run_validate(validation_dataset_shard)
    validation_loss_collected = run_validate(val_dataloader)
    # collect all metrics
    # use .item() to obtain a value that can be reported
    mape_collected = mape.compute().item()
    mean_validation_loss_collected = mean_validation_loss.compute().item()
    acc1_collected = acc1.compute().item()
    # acc5_collected = acc5.compute().item()
    # TODO
    acc5_collected = None

    return mape_collected, validation_loss_collected, mean_validation_loss_collected, acc1_collected, acc5_collected


def load_data():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    train_dir = "/home/ec2-user/RAY_GIT/Imagenet/imagenet/train"
    train_dataset = datasets.ImageFolder(
        train_dir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    val_dir = "/home/ec2-user/RAY_GIT/Imagenet/imagenet/val"
    val_dataset = datasets.ImageFolder(
        val_dir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
    return train_dataset, val_dataset

def train_loop_per_worker(config: dict):
    # AMP
    train.torch.accelerate(amp=True)
    device = train.torch.get_device()
    rank = session.get_world_rank()

    # Parameters
    batch_size = config["batch_size"]
    lr = config["lr"]
    momentum = config["momentum"]
    weight_decay = config["weight_decay"]
    epochs = config["num_epochs"]
    #total_train_dataset_len = config["total_train_dataset_len"]
    # We dynamically adjust the worker batch size according to the number of workers:
    batch_size_per_worker = batch_size // session.get_world_size()

    print("=> creating model '{}'".format("resnet18"))
    model = models.__dict__["resnet18"]()
    #  train.torch.prepare_model() also automatically takes care of setting up devices (e.g. GPU training) - so we can get rid of those lines in our current code!
    model = train.torch.prepare_model(model)
    # model.cuda(args.gpu) <== not needed

    # define loss function (criterion), optimizer, and learning rate scheduler
    # TODO hopefully ray train takes care of setting up the device properly
    criterion = nn.CrossEntropyLoss()
    # criterion = nn.CrossEntropyLoss().to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)
    # AMP
    optimizer = train.torch.prepare_optimizer(optimizer)

    # Create data loaders.
    train_dataset, val_dataset = load_data()
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, num_replicas=session.get_world_size(), rank=rank)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size_per_worker, shuffle=(train_sampler is None),
        num_workers=session.get_world_size(), pin_memory=True, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size_per_worker, shuffle=False,
        num_workers=session.get_world_size(), pin_memory=True, sampler=val_sampler)

    train_dataloader = train.torch.prepare_data_loader(train_loader)
    val_dataloader = train.torch.prepare_data_loader(val_loader)


    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    scheduler = StepLR(optimizer, step_size=30, gamma=0.1)

    # NOTE: metrics
    # measure accuracy and record loss
    mape = torchmetrics.MeanAbsolutePercentageError().to(device)
    # for averaging loss
    mean_validation_loss = torchmetrics.MeanMetric().to(device)
    acc1 = torchmetrics.Accuracy(task="multiclass", num_classes=1000, top_k=1).to(device)
    # TODO
    acc5 = None
    # acc5 = torchmetrics.Accuracy(task="multiclass", num_classes=1000, top_k=5).to(device)



    start_epoch = 0

    checkpoint = session.get_checkpoint()
    if checkpoint:
        # assume that we have run the session.report() example
        # and successfully saved some model weights
        checkpoint_dict = checkpoint.to_dict()
        model.load_state_dict(checkpoint_dict.get("model_weights"))
        start_epoch = checkpoint_dict.get("epoch", -1) + 1

    epochs_read = 0
    # We will iterate over epoch-delimited pipeline segments; streaming will create windows of blocks whose order changes within a worker. Data is nicely
    # split between workers for one epoch, though all workers see all data over multiple epochs. This seems to be in contrast to PyTorch DataParallel with shuffle=False, where a worker sees the same data every epoch.
    # reference : https://github.com/ray-project/ray/issues/18287
    # for epoch, train_dataset_shard_batches in enumerate(train_dataset_shard.iter_epochs()):
    total_time_in_training = 0.
    for epoch in range(start_epoch, epochs):
        print(" Epoch ", epoch)
        epochs_read += 1


        # train for one epoch
        # train_epoch(train_dataset_shard, model, criterion, optimizer, batch_size=batch_size, mape=mape,
        epoch_start_time = time.time()
        """
        train_epoch(epoch, train_dataset_shard, model, criterion, optimizer, batch_size=batch_size_per_worker, mape=mape,
                    mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5,
                    total_train_dataset_len=total_train_dataset_len)
        """
        train_epoch(epoch, train_dataloader, model, criterion, optimizer, batch_size=batch_size_per_worker, mape=mape,
                    mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5)
        training_epoch_time = time.time() - epoch_start_time
        total_time_in_training += training_epoch_time
        print("Training time in epoch for rank {} = {}".format(rank, training_epoch_time))

        validation_start_time = time.time()
        mape_collected, validation_loss_collected, mean_validation_loss_collected, acc1_collected, acc5_collected = validate_epoch(
            epoch, val_dataloader, model, criterion, batch_size=batch_size, mape=mape,
            mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5)
        """
        mape_collected, validation_loss_collected, mean_validation_loss_collected, acc1_collected, acc5_collected = validate_epoch(
            epoch, validation_dataset_shard, model, criterion, batch_size=batch_size, mape=mape,
            mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5)
        """
        print("Validation time in epoch for rank {} = {}".format(rank, time.time() - validation_start_time))

        # checkpoint for fault tolerance
        state_dict = model.state_dict()
        consume_prefix_in_state_dict_if_present(state_dict, "module.")
        checkpoint = Checkpoint.from_dict(
            # TODO, where is loss coming from, perhaps from validation
            dict(epoch=epoch, model_weights=state_dict, mean_validation_loss_collected=mean_validation_loss_collected)
        )
        # Report session and checkpoint
        session.report(
            {
                "mape_collected": mape_collected,
                "validation_loss_collected": validation_loss_collected,
                "mean_validation_loss_collected": mean_validation_loss_collected,
                "acc1_collected": acc1_collected,
                # "acc5_collected": acc5_collected
            },
            checkpoint=checkpoint
        )

        # reset for next epoch
        mape.reset()
        mean_validation_loss.reset()
        acc1.reset()
        # acc5.reset()


        scheduler.step()

    print("******************* Total time in training: ", total_time_in_training)
    # NOTE https://docs.ray.io/en/latest/train/dl_guide.html#configuring-checkpoints


# Keep the 2 checkpoints with the smallest "mean_validation_loss_collected" value.
checkpoint_config = CheckpointConfig(
    num_to_keep=2, checkpoint_score_attribute="mean_validation_loss_collected", checkpoint_score_order="min"
)

# NOTE: we just use this preprocessor to convert categories to integers
dataset_preprocessor = LabelEncoder(label_column="category")
# NOTE refer to https://docs.ray.io/en/latest/train/dl_guide.html#configuring-training
# https://docs.ray.io/en/latest/train/api.html#pytorch
trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"num_epochs": 1, "batch_size": 256, "lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4},
    # https://docs.ray.io/en/latest/train/dl_guide.html#example-logging-to-mlflow-and-tensorboard

    #datasets={"train": train_dataset, "validation": val_dataset},
    # To scale your training script, create a Ray Cluster and increase the number of workers. If your cluster contains GPUs, add "use_gpu": True to your scaling config.
    # scaling_config=ScalingConfig(num_workers=8, use_gpu=True, trainer_resources={"CPU": 20}, resources_per_worker={"CPU": 2, "GPU": 1}),
    scaling_config=ScalingConfig(num_workers=8, use_gpu=True, resources_per_worker={"CPU": 2, "GPU": 1}),
    run_config=RunConfig(
        checkpoint_config=checkpoint_config,
        callbacks=[
            MLflowLoggerCallback(experiment_name="train_experiment"),
            TBXLoggerCallback(),
        ],
        # https://docs.ray.io/en/latest/ray-air/package-ref.html#ray.air.config.RunConfig
        verbose=3,
        log_to_file=True
    ),

)

# NOTE: interpret training results https://docs.ray.io/en/latest/ray-air/trainer.html#how-to-interpret-training-results
result = trainer.fit()


Here are some dataset reading stats captured while training on the p3.16xlarge when using Ray Datasets:

(RayTrainWorker pid=30233) loss: 6.636719  [160000/160145], batch_time: 1.9915745258331299
(RayTrainWorker pid=30226) Stage 1 read->map_batches: 20000/20000 blocks executed in 5191.45s
(RayTrainWorker pid=30226) * Remote wall time: 441.45ms min, 5.08s max, 1.52s mean, 30411.47s total
(RayTrainWorker pid=30226) * Remote cpu time: 253.41ms min, 1.88s max, 523.54ms mean, 10470.78s total
(RayTrainWorker pid=30226) * Peak heap memory usage (MiB): 521.66 min, 15576.22 max, 8021 mean
(RayTrainWorker pid=30226) * Output num rows: 64 min, 65 max, 64 mean, 1281167 total
(RayTrainWorker pid=30226) * Output size bytes: 38535808 min, 39137928 max, 38570941 mean, 771418834040 total
(RayTrainWorker pid=30226) * Tasks per node: 20000 min, 20000 max, 20000 mean; 1 nodes used
(RayTrainWorker pid=30226) 
(RayTrainWorker pid=30226) Stage 2 randomize_block_order: 20000/20000 blocks executed in 0.38s
(RayTrainWorker pid=30226) * Remote wall time: 441.45ms min, 5.08s max, 1.52s mean, 30411.47s total
(RayTrainWorker pid=30226) * Remote cpu time: 253.41ms min, 1.88s max, 523.54ms mean, 10470.78s total
(RayTrainWorker pid=30226) * Peak heap memory usage (MiB): 521.66 min, 15576.22 max, 8021 mean
(RayTrainWorker pid=30226) * Output num rows: 64 min, 65 max, 64 mean, 1281167 total
(RayTrainWorker pid=30226) * Output size bytes: 38535808 min, 39137928 max, 38570941 mean, 771418834040 total
(RayTrainWorker pid=30226) * Tasks per node: 20000 min, 20000 max, 20000 mean; 1 nodes used

Also, while using “DummyTrainer-like logging”, I see the following in the training loop:

(RayTrainWorker pid=30229) Time to read all data in epoch 1669.1098881141224 seconds
(RayTrainWorker pid=30229) P50/P95/Max batch delay (s) 0.06657816699589603 1.287668951204978 3.736117245978676
(RayTrainWorker pid=30229) Num batches read 5005
(RayTrainWorker pid=30229) Average number of rows in each batch:  31.997002997002998
(RayTrainWorker pid=30229) Num bytes read 22999.34 MiB
(RayTrainWorker pid=30229) Mean Dataset throughput 13.78 MiB/s

@chengsu can you take a look here?

Hey @vgill, yeah, it looks like the slowdown is coming from the read->map_batches step, based on the stats you shared.

Happy to jump on a call sometime and we can profile this together. You can find me on the Ray slack!