This is my training code with Ray Datasets:
# NOTE read this: https://docs.ray.io/en/latest/train/dl_guide.html#porting-code-to-ray-train
import numpy as np
import torch
from torchvision import transforms
# from ray.data.preprocessors import TorchVisionPreprocessor, Chain, LabelEncoder
from ray.data.preprocessors import Chain, LabelEncoder
from ray.data.datasource.partitioning import PartitionStyle
import ray
from ray import train
from ray.air import session, RunConfig, Checkpoint, CheckpointConfig
from ray.train.torch import TorchCheckpoint
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig, DatasetConfig
from ray.air.integrations.mlflow import MLflowLoggerCallback
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from ray.tune.logger import TBXLoggerCallback
import torchvision.models as models
import torchmetrics
import time
import pandas as pd
from io import BytesIO
from PIL import Image
from ray.air.util.tensor_extensions.pandas import TensorArray
from ray.data import ActorPoolStrategy
import sys
from ray.data.dataset import Dataset
from typing import Tuple, List, Dict
import os
from ray.data.preprocessor import Preprocessor
from pandas import DataFrame
import json
# NOTE should read files in parallel after figuring out paths
# read parallelism criteria: https://docs.ray.io/en/latest/data/creating-datasets.html#read-parallelism
# dataset = ray.data.read_images("s3://anonymous@air-example-data-2/imagenet-sample-images")
from ray.data.datasource.partitioning import Partitioning
# From https://pytorch.org/vision/main/_modules/torchvision/datasets/folder.html#ImageFolder
# This code was used to create synset.json
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
"""Finds the class folders in a dataset.
See :class:`DatasetFolder` for details.
"""
classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
if not classes:
raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
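# Minimal sketch (not run here, hence commented out) of how synset.json could be generated with
# find_classes(); the "classes"/"class_to_idx" keys match what is loaded just below. Reusing the
# training directory path for this is an assumption.
# cls_names, cls_to_idx = find_classes("/home/ec2-user/RAY_GIT/Imagenet/imagenet/train")
# with open("./synset.json", "w") as f:
#     json.dump({"classes": cls_names, "class_to_idx": cls_to_idx}, f)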
with open("./synset.json", "r") as f_synset:
    synset_dict = json.load(f_synset)
classes = synset_dict["classes"]
class_to_idx = synset_dict["class_to_idx"]
train_root = "/home/ec2-user/RAY_GIT/Imagenet/imagenet/train"
#train_root = "s3://a9vs-photon-model-training-exp/imagenet/imagenet/train"
train_partitioning = Partitioning(PartitionStyle.DIRECTORY, field_names=["category"], base_dir=train_root)
train_dataset = ray.data.read_binary_files(paths=train_root, partitioning=train_partitioning, parallelism=20000)
# train_dataset_pipe = ray.data.read_binary_files(paths=train_root, partitioning=train_partitioning).window(bytes_per_window=40e9)
print(train_dataset)
# print(train_dataset.schema())
total_train_dataset_len = train_dataset.count()
print(total_train_dataset_len)
# print(train_dataset.default_batch_format())
# print(train_dataset.fully_executed().size_bytes())
val_root = "/home/ec2-user/RAY_GIT/Imagenet/imagenet/val"
#val_root = "s3://a9vs-photon-model-training-exp/imagenet/imagenet/val"
val_partitioning = Partitioning(PartitionStyle.DIRECTORY, field_names=["category"], base_dir=val_root)
val_dataset = ray.data.read_binary_files(paths=val_root, partitioning=val_partitioning)
print(val_dataset)
# https://docs.ray.io/en/latest/ray-air/preprocessors.html#implementing-custom-preprocessors
# class CustomPreprocessor(BatchMapper):
class CustomPreprocessor(Preprocessor):
_is_fittable = False
def __init__(
self,
classes,
class_to_idx):
self.classes = classes
self.class_to_idx = class_to_idx
"""
def _fit(self, dataset: Dataset) -> Preprocessor:
self.stats_ = dataset.aggregate(Max("value"))
"""
def _transform_pandas(self, item_list: DataFrame) -> DataFrame:
        # item_list arrives as a pandas.core.frame.DataFrame
# print(type(item_list))
# print(item_list)
byte_items = item_list['bytes']
categories = item_list['category']
# print(len(byte_items), len(categories))
preprocess = transforms.Compose([
# transforms.Resize(256),
# transforms.CenterCrop(224),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
images = [Image.open(BytesIO(byte_item)).convert('RGB') for byte_item in byte_items]
images = [preprocess(image) for image in images]
images = [np.array(image) for image in images]
categories = [self.class_to_idx[category] for category in categories]
return pd.DataFrame({"image": TensorArray(images), "category": categories})
custom_preprocessor = CustomPreprocessor(classes, class_to_idx)
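# Optional sanity check (a sketch, not needed for training): apply the preprocessor to a small
# sample and confirm it yields "image" tensors and integer "category" labels. Dataset.limit()
# and Preprocessor.transform() are assumed to be available with these signatures in this Ray version.
# sample = custom_preprocessor.transform(train_dataset.limit(2))
# print(sample.take(1))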
# There needs to be a validate function too
# def train_epoch(train_dataset_shard, model, criterion, optimizer, **kwargs):
def train_epoch(epoch, train_dataset_shard, model, criterion, optimizer, **kwargs):
"""
OBSOLETE
batch_time = AverageMeter('Time', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
progress = ProgressMeter(
len(train_loader),
[batch_time, data_time, losses, top1, top5],
prefix="Epoch: [{}]".format(epoch))
"""
rank = session.get_world_rank()
batch_size = kwargs['batch_size']
mape = kwargs['mape']
mean_validation_loss = kwargs['mean_validation_loss']
acc1 = kwargs['acc1']
# acc5 = kwargs['acc5']
total_train_dataset_len = kwargs['total_train_dataset_len']
# switch to train mode
model.train()
end = time.time()
# NOTE: when you load from a dataloader, use the following
# size = len(dataloader.dataset) // session.get_world_size()
# The following does not work: ValueError("Cannot count a pipeline of infinite length.")
# ds_shard_size = train_dataset_shard.count()
ds_shard_size = total_train_dataset_len // session.get_world_size()
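    # For example, if this were the full ImageNet train split (1,281,167 images, an illustrative assumption)
    # split across 8 workers, ds_shard_size would be 1,281,167 // 8 = 160,145 rows.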
train_dataset_batches = train_dataset_shard.iter_torch_batches(
device=train.torch.get_device(),
        # prefetch_blocks may only apply to newer Ray versions; check whether it is needed, per https://github.com/ray-project/ray/pull/31692/files
        prefetch_blocks=3,
        batch_size=batch_size,  # default is 256; use None to make each block a batch
        local_shuffle_buffer_size=batch_size * 2  # should be greater than or equal to the batch size
)
start = time.perf_counter()
batches_read, bytes_read = 0, 0
batch_start = time.perf_counter()
batch_delays = []
batch_sizes = []
# print("train_dataset_shard_batches: ", train_dataset_shard_batches)
print("train_dataset_batches: ", train_dataset_batches)
for i, batch in enumerate(train_dataset_batches):
batch_delay = time.perf_counter() - batch_start
batch_delays.append(batch_delay)
batches_read += 1
# get the images and classes
images, target = batch["image"], batch["category"]
bytes_read += (images.size(0) * (224 * 224 * 3 + 64))
batch_sizes.append(images.size(0))
# TODO
# measure data loading time
# data_time.update(time.time() - end)
"""
OBSOLETE
# move data to the same device as model
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
"""
# compute output
output = model(images)
loss = criterion(output, target)
batch_time = time.time() - end
end = time.time()
if i and i % 20 == 0:
loss_item, current = loss.item(), i * images.size(0)
print(f"loss: {loss_item:>7f} [{current:>5d}/{ds_shard_size:>5d}], batch_time: {batch_time}")
# Debugging for dataset and datasetpipeline https://docs.ray.io/en/latest/data/performance-tips.html#debugging-statistics
print(train_dataset_shard.stats())
# compute gradient and do SGD step
optimizer.zero_grad()
# loss.backward()
# AMP
train.torch.backward(loss)
optimizer.step()
batch_start = time.perf_counter()
delta = time.perf_counter() - start
print("Epoch: {}, Training Stats for rank: ".format(epoch), rank)
total_time_reading = np.sum(batch_delays)
print("Time to read all data in epoch", total_time_reading, "seconds")
print(
"P50/P95/Max batch delay (s)",
np.quantile(batch_delays, 0.5),
np.quantile(batch_delays, 0.95),
np.max(batch_delays),
)
print("Num batches read", batches_read)
print("Average number of rows in each batch: ", np.mean(batch_sizes))
print("Num bytes read", round(bytes_read / (1024 * 1024), 2), "MiB")
print(
"Mean Dataset throughput", round(bytes_read / (1024 * 1024) / total_time_reading, 2), "MiB/s"
)
def validate_epoch(epoch, validation_dataset_shard, model, criterion, **kwargs):
batch_size = kwargs['batch_size']
mape = kwargs['mape']
mean_validation_loss = kwargs['mean_validation_loss']
acc1 = kwargs['acc1']
# acc5 = kwargs['acc5']
def run_validate(validation_dataset_shard, base_progress=0):
rank = session.get_world_rank()
with torch.no_grad():
end = time.time()
            validation_dataset_batches = validation_dataset_shard.iter_torch_batches(
                prefetch_blocks=10,
                batch_size=batch_size,  # default is 256; use None to make each block a batch
                # drop_last=drop_last,
                local_shuffle_buffer_size=batch_size * 2,  # should be greater than or equal to the batch size
                device=train.torch.get_device()
                # prefetch_blocks may only apply to newer Ray versions; check whether it is needed, per https://github.com/ray-project/ray/pull/31692/files
                # local_shuffle_seed=local_shuffle_seed,
            )
running_loss = 0.
start = time.perf_counter()
batches_read, bytes_read = 0, 0
batch_start = time.perf_counter()
batch_delays = []
batch_sizes = []
for i, batch in enumerate(validation_dataset_batches):
batch_delay = time.perf_counter() - batch_start
batch_delays.append(batch_delay)
batches_read += 1
# get the images and classes
images, target = batch["image"], batch["category"]
i = base_progress + i
bytes_read += (images.size(0) * (224 * 224 * 3 + 64))
batch_sizes.append(images.size(0))
# compute output
output = model(images)
loss = criterion(output, target)
# measure accuracy and record loss
# save loss in aggregator
running_loss += loss.item()
mean_validation_loss(loss)
mape(torch.argmax(output, dim=1), target)
acc1(torch.argmax(output, dim=1), target)
# acc5(torch.argmax(output, dim=1), target)
# measure elapsed time
batch_time = time.time() - end
end = time.time()
batch_start = time.perf_counter()
print("Epoch: {}, Validation Stats for rank: ".format(epoch), rank)
total_time_reading = np.sum(batch_delays)
print("Time to read all data in epoch", total_time_reading, "seconds")
print(
"P50/P95/Max batch delay (s)",
np.quantile(batch_delays, 0.5),
np.quantile(batch_delays, 0.95),
np.max(batch_delays),
)
print("Num batches read", batches_read)
print("Average number of rows in each batch: ", np.mean(batch_sizes))
print("Num bytes read", round(bytes_read / (1024 * 1024), 2), "MiB")
print(
"Mean Dataset throughput", round(bytes_read / (1024 * 1024) / total_time_reading, 2), "MiB/s"
)
            valid_loss_collected = running_loss / batches_read  # average loss over the number of batches actually read
return valid_loss_collected
# switch to evaluate mode
model.eval()
validation_loss_collected = run_validate(validation_dataset_shard)
# collect all metrics
# use .item() to obtain a value that can be reported
mape_collected = mape.compute().item()
mean_validation_loss_collected = mean_validation_loss.compute().item()
acc1_collected = acc1.compute().item()
# acc5_collected = acc5.compute().item()
# TODO
acc5_collected = None
return mape_collected, validation_loss_collected, mean_validation_loss_collected, acc1_collected, acc5_collected
# We import Ray Train and Ray AIR Session:
# We use a config dict to configure some hyperparameters (not strictly needed, but good practice, especially if you want to do hyperparameter tuning later):
# https://docs.ray.io/en/latest/ray-air/package-ref.html#ray.train.data_parallel_trainer.DataParallelTrainer
# Useful example code:
# 1) https://docs.ray.io/en/latest/ray-air/examples/torch_image_example.html
# 2) https://docs.ray.io/en/latest/train/examples/pytorch/torch_fashion_mnist_example.html
# 3) https://docs.ray.io/en/latest/ray-air/check-ingest.html#disabling-preprocessor-transforms
# 4) https://docs.ray.io/en/latest/train/dl_guide.html#distributed-data-ingest-with-ray-datasets
# 5) https://docs.ray.io/en/latest/ray-air/examples/convert_existing_pytorch_code_to_ray_air.html
def train_loop_per_worker(config: dict):
# AMP
train.torch.accelerate(amp=True)
device = train.torch.get_device()
rank = session.get_world_rank()
# Reproducibility as per https://docs.ray.io/en/latest/train/dl_guide.html#reproducibility
# train.torch.enable_reproducibility()
# Parameters
batch_size = config["batch_size"]
lr = config["lr"]
momentum = config["momentum"]
weight_decay = config["weight_decay"]
epochs = config["num_epochs"]
total_train_dataset_len = config["total_train_dataset_len"]
# We dynamically adjust the worker batch size according to the number of workers:
batch_size_per_worker = batch_size // session.get_world_size()
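    # For example, with the default train_loop_config below (batch_size=256) and 8 workers,
    # each worker trains on 256 // 8 = 32 samples per batch.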
print("=> creating model '{}'".format("resnet18"))
model = models.__dict__["resnet18"]()
# train.torch.prepare_model() also automatically takes care of setting up devices (e.g. GPU training) - so we can get rid of those lines in our current code!
model = train.torch.prepare_model(model)
# model.cuda(args.gpu) <== not needed
# define loss function (criterion), optimizer, and learning rate scheduler
# TODO hopefully ray train takes care of setting up the device properly
criterion = nn.CrossEntropyLoss()
# criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr,
momentum=momentum,
weight_decay=weight_decay)
# AMP
optimizer = train.torch.prepare_optimizer(optimizer)
# Get worker shards over data
# Datasets can be accessed in your train_func via ``get_dataset_shard``.
# This has become a pipeline due to streaming in the dataset config
train_dataset_shard = session.get_dataset_shard("train")
# This is a dataset
validation_dataset_shard = session.get_dataset_shard("validation")
print("train_dataset_shard: ", train_dataset_shard)
print("validation_dataset_shard: ", validation_dataset_shard)
"""Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
# NOTE: metrics
# measure accuracy and record loss
mape = torchmetrics.MeanAbsolutePercentageError().to(device)
# for averaging loss
mean_validation_loss = torchmetrics.MeanMetric().to(device)
acc1 = torchmetrics.Accuracy(task="multiclass", num_classes=1000, top_k=1).to(device)
# TODO
acc5 = None
# acc5 = torchmetrics.Accuracy(task="multiclass", num_classes=1000, top_k=5).to(device)
start_epoch = 0
checkpoint = session.get_checkpoint()
if checkpoint:
        # assume a previous session.report() call saved model weights
        # NOTE: the saved weights had the "module." prefix stripped, so they may need to be loaded
        # into the unwrapped model (or with the prefix restored) once prepare_model wraps it in DDP
checkpoint_dict = checkpoint.to_dict()
model.load_state_dict(checkpoint_dict.get("model_weights"))
start_epoch = checkpoint_dict.get("epoch", -1) + 1
epochs_read = 0
    # We iterate over epoch-delimited pipeline segments. With streaming, windows of blocks arrive in an order that
    # changes within a worker; data is split cleanly between workers within one epoch, although all workers see
    # all of the data across multiple epochs. This is in contrast to PyTorch DataParallel with shuffle=False, where
    # a worker sees the same data every epoch.
    # Reference: https://github.com/ray-project/ray/issues/18287
    # for epoch, train_dataset_shard_batches in enumerate(train_dataset_shard.iter_epochs()):
total_time_in_training = 0.
for epoch in range(start_epoch, epochs):
print(" Epoch ", epoch)
epochs_read += 1
# train for one epoch
# train_epoch(train_dataset_shard, model, criterion, optimizer, batch_size=batch_size, mape=mape,
epoch_start_time = time.time()
train_epoch(epoch, train_dataset_shard, model, criterion, optimizer, batch_size=batch_size_per_worker, mape=mape,
mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5,
total_train_dataset_len=total_train_dataset_len)
training_epoch_time = time.time() - epoch_start_time
total_time_in_training += training_epoch_time
print("Training time in epoch for rank {} = {}".format(rank, training_epoch_time))
validation_start_time = time.time()
mape_collected, validation_loss_collected, mean_validation_loss_collected, acc1_collected, acc5_collected = validate_epoch(
            epoch, validation_dataset_shard, model, criterion, batch_size=batch_size, mape=mape,  # NOTE: validation uses the global batch size here; training uses batch_size_per_worker
mean_validation_loss=mean_validation_loss, acc1=acc1, acc5=acc5)
print("Validation time in epoch for rank {} = {}".format(rank, time.time() - validation_start_time))
# checkpoint for fault tolerance
state_dict = model.state_dict()
consume_prefix_in_state_dict_if_present(state_dict, "module.")
checkpoint = Checkpoint.from_dict(
            # the loss stored here is the mean validation loss collected for this epoch
dict(epoch=epoch, model_weights=state_dict, mean_validation_loss_collected=mean_validation_loss_collected)
)
# Report session and checkpoint
session.report(
{
"mape_collected": mape_collected,
"validation_loss_collected": validation_loss_collected,
"mean_validation_loss_collected": mean_validation_loss_collected,
"acc1_collected": acc1_collected,
# "acc5_collected": acc5_collected
},
checkpoint=checkpoint
)
# reset for next epoch
mape.reset()
mean_validation_loss.reset()
acc1.reset()
# acc5.reset()
scheduler.step()
print("******************* Total time in training: ", total_time_in_training)
# NOTE https://docs.ray.io/en/latest/train/dl_guide.html#configuring-checkpoints
# Keep the 2 checkpoints with the smallest "mean_validation_loss_collected" value.
checkpoint_config = CheckpointConfig(
num_to_keep=2, checkpoint_score_attribute="mean_validation_loss_collected", checkpoint_score_order="min"
)
# NOTE: this LabelEncoder (which would convert categories to integers) is defined but unused;
# custom_preprocessor is what is actually passed to the trainer below
dataset_preprocessor = LabelEncoder(label_column="category")
# NOTE refer to https://docs.ray.io/en/latest/train/dl_guide.html#configuring-training
# https://docs.ray.io/en/latest/train/api.html#pytorch
trainer = TorchTrainer(
train_loop_per_worker=train_loop_per_worker,
train_loop_config={"num_epochs": 1, "batch_size": 256, "lr": 0.1, "momentum": 0.9, "weight_decay": 1e-4,
"total_train_dataset_len": total_train_dataset_len},
# https://docs.ray.io/en/latest/train/dl_guide.html#example-logging-to-mlflow-and-tensorboard
datasets={"train": train_dataset, "validation": val_dataset},
# To scale your training script, create a Ray Cluster and increase the number of workers. If your cluster contains GPUs, add "use_gpu": True to your scaling config.
# scaling_config=ScalingConfig(num_workers=8, use_gpu=True, trainer_resources={"CPU": 20}, resources_per_worker={"CPU": 2, "GPU": 1}),
scaling_config=ScalingConfig(num_workers=8, use_gpu=True, resources_per_worker={"CPU": 2, "GPU": 1}),
run_config=RunConfig(
checkpoint_config=checkpoint_config,
callbacks=[
MLflowLoggerCallback(experiment_name="train_experiment"),
TBXLoggerCallback(),
],
# https://docs.ray.io/en/latest/ray-air/package-ref.html#ray.air.config.RunConfig
verbose=3,
log_to_file=True
),
dataset_config={
"train": DatasetConfig(fit=True, split=True, global_shuffle=False),
},
# preprocessor = dataset_preprocessor,
# preprocessor = preprocessor,
preprocessor=custom_preprocessor
# if you wish to resume from a previously saved checkpoint
# resume_from_checkpoint=result.checkpoint,
)
# NOTE: interpret training results https://docs.ray.io/en/latest/ray-air/trainer.html#how-to-interpret-training-results
result = trainer.fit()
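# Sketch of inspecting the training result (attribute names follow the Ray AIR Result API;
# verify them against the installed Ray version):
# print(result.metrics)     # last reported metrics, e.g. mean_validation_loss_collected
# print(result.checkpoint)  # the last checkpoint reported via session.report()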