Challenges and Questions on Implementing Custom Sampling Strategy for CNN Training with Large Multichannel TIFF Files in Ray

Context:

I am working on a project where we train a Convolutional Neural Network (CNN) using large, multichannel TIFF files stored in an S3 bucket (totaling terabytes of data). The metadata, including class labels and image paths, is stored in a Parquet file. Given the imbalanced nature of our dataset, we are exploring a custom sampling strategy to use with Ray, since Ray does not appear to directly support PyTorch’s weighted sampler.

Strategy Overview:

Our proposed strategy involves duplicating rows of the metadata file (it is comparatively small) to oversample under-represented classes, and copying the image data to a local tmp folder so it does not have to be re-fetched from the S3 bucket on every epoch. Although I’ve implemented a basic version of this strategy and attempted to optimize it with map_batches, the performance improvements were not as expected. A rough sketch of the metadata-duplication step follows below.
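To make the duplication idea concrete, here is roughly what I mean by the metadata step. The Parquet path and the 'path'/'label' column names are placeholders for our actual schema, not the real values:

import numpy as np
import pandas as pd

# Placeholder path; the real metadata lives in a Parquet file on S3.
meta = pd.read_parquet("metadata.parquet")

# Weight each row inversely to its class frequency, then duplicate rows
# roughly weight-many times so that a shuffled pass over the metadata
# approximates weighted sampling.
class_counts = meta["label"].value_counts()
weights = meta["label"].map(class_counts.max() / class_counts)
repeats = np.ceil(weights).astype(int).to_numpy()
balanced_meta = meta.loc[meta.index.repeat(repeats)].sample(frac=1.0, random_state=0)

# The balanced metadata is then handed to Ray Data, e.g.:
# ds = ray.data.from_pandas(balanced_meta)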

Questions and Issues:

  1. Strategy Validation: Is the approach I’ve outlined above recommended for our use case, or are there aspects I’m overlooking that might suggest a different strategy?
  2. Batch Size Management: How can I effectively control the batch size to align with my GPU’s capacity (roughly 32 images, each ~2 to ~6 MB)? For instance, setting a batch size of 4 unexpectedly results in 26 files being written, because every worker still fetches whole blocks of files but only filters and returns 4 of them. This is not a problem for a small number of files, but it becomes a scalability concern with larger datasets. In my pipeline test with a batch size of 8, the first batch led to ~96 file writes, which stretched the first epoch to 40 minutes compared to 8 minutes for subsequent epochs. GPU utilization was only around 30% during the first epoch and rose to about 60% from the second epoch onwards. Attempting to increase the batch size for better resource utilization led to cluster crashes and slowdowns. (A sketch of the block/batch alignment I have been experimenting with follows right after this list.)
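The only lever I have found so far is aligning the number of blocks with the batch size before the map stages run. A rough sketch of that idea, where the file names and counts are placeholders and I am not sure this is the intended mechanism:

import ray

BATCH_SIZE = 32  # rough GPU capacity mentioned above

# Placeholder file list standing in for the real paths from the metadata file.
files = [{"file": f"img_{i}.tiff"} for i in range(1000)]
ds = ray.data.from_items(files)

# Repartition so each block holds roughly one GPU batch worth of rows; the
# intent is that the download/decode map stages then materialize files one
# batch at a time instead of pulling whole oversized blocks per worker.
ds = ds.repartition(max(1, ds.count() // BATCH_SIZE))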

Seeking Guidance:

I am looking for insights or recommendations on optimizing this process, particularly regarding efficient data handling and improving GPU utilization without compromising system stability. Any advice on better integrating Ray’s capabilities with our project’s requirements would be greatly appreciated.

import os
import shutil

import numpy as np
import ray
import ray.data
import tifffile
import torch

@ray.remote
class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count

    def get_count(self):
        return self.count

def save_image(row, path, counter):
    """
    Generates random image data and writes it to a local TIFF file,
    standing in for downloading the real image from S3.
    Skips the write if the file already exists locally.
    """
    print(f"Saving image: {row['file']}")
    os.makedirs(path, exist_ok=True)
    file = row['file']
    im_path = os.path.join(path, file)
    print(im_path)
    if os.path.exists(im_path):
        print(f'skip write: {im_path}')
        row['file'] = im_path
        return row

    # Random pixels stand in for the real image content.
    image_data = np.random.randint(0, 256, size=(100, 100), dtype=np.uint8)
    tifffile.imwrite(im_path, image_data)
    ray.get(counter.increment.remote())
    row['file'] = im_path
    return row



def read_image(row, counter):
    """
    Reading an image.
    """
    print(f"Reading image from {row['file']}")
    row['image'] = tifffile.imread(row['file'])
    ray.get(counter.increment.remote())
    return row

def teardown(file_dir):
    """
    Removes the contents of the image directory (creates it if it does not exist).
    """
    if not os.path.exists(file_dir):
        os.makedirs(file_dir, exist_ok=True)
    else:
        for filename in os.listdir(file_dir):
            file_path = os.path.join(file_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print('Failed to delete %s. Reason: %s' % (file_path, e))

def collate_fn(batch):
    """Convert images to tensors."""
    print(f"Creating a batch: {list(batch['file'])}")
    images_tensor = [torch.tensor(im.astype(float), dtype=torch.float32) for im in batch['image']]
    return {'images': images_tensor}


FILE_PATH = '/tmp/alphabet'
MAX_EPOCHS = 5
BATCH_SIZE = 4
CONCURRENCY = 5
BLOCK_SIZE = 4

ray.init()
ctx = ray.data.DataContext.get_current()
ctx.execution_options.verbose_progress = True

upper_ascii_range = range(65, 91, 1)
files = [f"{chr(i)}.tiff" for i in upper_ascii_range]
print(f'the dataset contains: {files}')

ds = ray.data.from_items([{"file": file} for file in files],
                         override_num_blocks=BLOCK_SIZE)
write_counter = Counter.remote()
read_counter = Counter.remote()

ds = (ds
    .map(save_image, fn_kwargs={"path": FILE_PATH, "counter": write_counter}, concurrency=CONCURRENCY)
    .map(read_image, fn_kwargs={"counter": read_counter}, concurrency=CONCURRENCY)
)

dataloader = ds.iter_torch_batches(batch_size=BATCH_SIZE,
                                   drop_last=True,
                                   collate_fn=collate_fn)
for epoch in range(MAX_EPOCHS):
    print(f"starting epoch: {epoch}")
    for batch in dataloader:
        num_writes = ray.get(write_counter.get_count.remote())
        num_reads = ray.get(read_counter.get_count.remote())
        print(f"there have been reads/writes: ({num_reads}/{num_writes}) so far")
downloaded_images = len(os.listdir(FILE_PATH))
total_writes = ray.get(write_counter.get_count.remote())
total_reads = ray.get(read_counter.get_count.remote())
print(f"Total writes to the directory: {total_writes}\n")
print(f"Total reads from the directory: {total_reads}\n")
print(f"There are {downloaded_images} images downloaded in {FILE_PATH}\n")

teardown(file_dir=FILE_PATH)

Alright, since last time someone pointed me towards this post:

which was very useful, and now I’ve given it another shot.

I’ve used a custom image datasource, since I’m dealing with multichannel images that are not supported by the default Ray image reader. I then created a random access dataset in which I store the images (because I’d like to access them multiple times to perform weighted sampling). Here is my code. Is this sensible?

import io
from typing import Iterator

import numpy as np
import pandas as pd
import pyarrow
import ray
import tifffile
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data.block import Block
from ray.data.datasource import ImageDatasource

class MultiTiffDatasource(ImageDatasource):
    """A modified image datasource for reading multi-page TIFF files."""

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        data = f.readall()
        if path.lower().endswith((".tif", ".tiff")):
            image = read_tiff(io.BytesIO(data))
            builder = DelegatingBlockBuilder()
            array = np.array(image)
            item = {"image": array.astype(np.float32)}
            builder.add(item)
            block = builder.build()
            yield block

def read_tiff(tiff_buffer):
    """Reads a multi-frame TIFF and converts it to a NumPy array.

    tifffile stacks the frames along a new leading axis (axis=0), so the
    result is an array in which each slice along axis 0 is one frame of
    the TIFF.

    Args:
        tiff_buffer: A file-like object (e.g. io.BytesIO) containing the
            raw TIFF bytes.

    Returns:
        A NumPy array with one frame per slice along axis 0.
    """
    return tifffile.imread(tiff_buffer)



ds = ray.data.read_datasource(MultiTiffDatasource(paths="local:///data",
                                                  include_paths=True))
mds = ray.data.read_csv('local:///data/dataset.csv')

def create_imagepath(batch):
    """Renames the path column so it can serve as the random access key."""
    return {'image_path': batch['path'], 'image': batch['image']}

def get_images(batch):
    """Returns images from the random access dataset."""
    batch["image"] = [record['image'] for record in rmap.multiget(batch['image_path'])]
    return batch

ds = ds.map_batches(create_imagepath, batch_format="pandas")
rmap = ds.to_random_access_dataset(key="image_path", num_workers=4)

mds = mds.map_batches(get_images, batch_format="pandas")
full_ds = mds.take_batch()
print(full_ds)

Update: Storing my dataset, with paths as keys, in a random access dataset does not scale and is definitely not performant enough to access the files quickly during training :slight_smile: (it took 20 minutes to create the dataset, and then 10 minutes for a single batch to return).

In conclusion, to develop a weighted sampling approach, I experimented with several methods:

  1. Dataset Balancing: I balanced the dataset by upsampling and downsampling specific classes using a dataframe containing the paths.
  2. Image Retrieval and Processing: I used Ray map/map_batches to retrieve images from S3 and save them locally, which initially created a performance bottleneck. The first epoch took 40 minutes, while subsequent epochs took 8 minutes. GPU utilization was approximately 40% in the first epoch and increased to around 70% thereafter. However, increasing the batch size was challenging due to limited control over file outputs per block.
  3. TIFF File Handling: Using a dedicated tifffile reader with the Ray image datasource worked well on its own, but integrating it with my labels was problematic. I explored using Dataset.zip to merge the metadata with the image dataset (a toy version of that merge is sketched right after this list), but this was time-consuming, taking nearly 30 minutes due to the synthetic balancing of paths.
  4. Random Access Dataset: I tried storing images and metadata in a random access dataset with four workers, but this was inefficient, taking about 25 minutes, and caused issues with object spilling. Both methods eventually crashed during longer runs.
  5. Torch DataLoader with Boto3: By comparison, using a torch DataLoader integrated with boto3 was more efficient, taking only 15 minutes per epoch and achieving much higher GPU utilization (a rough sketch of this setup is included at the end of this post).
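For reference, this is the basic shape of the Dataset.zip merge mentioned in point 3, reduced to a toy example; in my real pipeline the image dataset came from the multi-TIFF datasource and the labels from the balanced metadata, which is where it became slow:

import ray

# Toy stand-ins: one dataset of decoded images, one of labels, same row count.
images = ray.data.from_items([{"image": [i, i, i]} for i in range(4)])
labels = ray.data.from_items([{"label": i % 2} for i in range(4)])

# zip pairs the datasets row-wise, so each row now carries both columns.
merged = images.zip(labels)
print(merged.take_batch())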

These trials highlighted the advantages of integrating a torch DataLoader with boto3 for this application, offering significant improvements in speed and GPU efficiency.
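For completeness, this is roughly the shape of the torch DataLoader + boto3 setup from point 5. The bucket name, keys, labels, and worker counts are placeholders rather than our real configuration, and error handling/retries are omitted:

import io

import boto3
import tifffile
import torch
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler


class S3TiffDataset(Dataset):
    """Loads multichannel TIFFs directly from S3 (bucket/keys/labels are placeholders)."""

    def __init__(self, bucket, keys, labels):
        self.bucket = bucket
        self.keys = keys
        self.labels = labels
        self._s3 = None  # created lazily so each DataLoader worker gets its own client

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        if self._s3 is None:
            self._s3 = boto3.client("s3")
        obj = self._s3.get_object(Bucket=self.bucket, Key=self.keys[idx])
        image = tifffile.imread(io.BytesIO(obj["Body"].read()))
        return torch.as_tensor(image, dtype=torch.float32), self.labels[idx]


# Weighted sampling is handled by torch itself instead of duplicating metadata rows:
# 'weights' holds one value per sample, e.g. inverse class frequency.
# dataset = S3TiffDataset("my-bucket", keys, labels)
# sampler = WeightedRandomSampler(weights, num_samples=len(dataset))
# loader = DataLoader(dataset, batch_size=32, sampler=sampler, num_workers=8)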
