Title:
Challenges and Questions on Implementing Custom Sampling Strategy for CNN Training with Large Multichannel TIFF Files in Ray
Body:
Context:
I am working on a project where we train a Convolutional Neural Network (CNN) using large, multichannel TIFF files stored in an S3 bucket (totaling terabytes of data). The metadata, including class labels and image paths, is stored in a Parquet file. Given the imbalanced nature of our dataset, we are exploring a custom sampling strategy to use with Ray, since Ray does not appear to directly support PyTorch’s WeightedRandomSampler.
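For reference, PyTorch’s WeightedRandomSampler draws sample indices with replacement according to per-sample weights; the same effect can be approximated outside the DataLoader with plain NumPy, which may be easier to wire into a Ray pipeline. A minimal sketch (the labels array is made up for illustration):

```python
import numpy as np

# Hypothetical class labels for 10 samples; class 0 is the majority class.
labels = np.array([0, 0, 0, 0, 0, 0, 0, 0, 1, 1])

# Weight each sample by the inverse frequency of its class, so
# minority-class samples are drawn as often as majority-class ones.
class_counts = np.bincount(labels)
weights = 1.0 / class_counts[labels]
probs = weights / weights.sum()

rng = np.random.default_rng(0)
# Draw one "epoch" of indices with replacement, mimicking
# torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True).
epoch_indices = rng.choice(len(labels), size=len(labels), replace=True, p=probs)
```

The resulting index array could be used to select metadata rows before handing them to Ray, rather than weighting inside the DataLoader.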
Strategy Overview:
Our proposed strategy involves duplicating metadata rows (since they are comparatively small) and copying the image data to a local /tmp folder, so that we avoid repeatedly retrieving the same files from the S3 bucket across epochs. Although I’ve implemented a basic version of this strategy and attempted to optimize it with map_batches, the performance improvements were not as expected.
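The row-duplication idea can be sketched on the metadata table itself, before it ever reaches Ray. Assuming the Parquet metadata has `file` and `label` columns (hypothetical names), each row can be repeated roughly in proportion to its inverse class frequency:

```python
import pandas as pd

# Hypothetical metadata table; in practice this would come from
# pd.read_parquet(...) on the S3 metadata file.
meta = pd.DataFrame({
    "file": [f"img_{i}.tiff" for i in range(6)],
    "label": [0, 0, 0, 0, 1, 1],
})

# Repeat each row so every class contributes roughly the same number
# of rows: repeats per class = max_class_count / class_count.
counts = meta["label"].value_counts()
repeats = (counts.max() / counts).round().astype(int)
balanced = meta.loc[meta.index.repeat(meta["label"].map(repeats))].reset_index(drop=True)

# The balanced frame could then be handed to Ray, e.g.
# ds = ray.data.from_pandas(balanced).random_shuffle()
```

This keeps the duplication cheap (only metadata rows are copied), while the actual image bytes are fetched once and cached locally.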
Questions and Issues:
- Strategy Validation: Is the approach I’ve outlined above recommended for our use case, or are there aspects I’m overlooking that might suggest a different strategy?
- Batch Size Management: How can I effectively control the batch size to align with my GPU’s capacity (approx. 32 images, with each image being ~2 to ~6 MB)? For instance, setting a batch size of 4 unexpectedly results in writing 26 files, as all workers still fetch files in blocks but only filter and return 4. This might not pose a problem for a small number of files, but scalability becomes a concern with larger datasets and more files. In my pipeline test with a batch size of 8, the first batch led to ~96 file writes, significantly extending the duration of the first epoch (40 minutes) compared to subsequent epochs (8 minutes). However, GPU utilization was only at 30% for the first epoch and increased to 60% from the second epoch onwards. Attempting to increase the batch size for better resource utilization led to cluster crashes and slowdowns.
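One way to reason about the over-fetching above is in terms of Ray Data blocks: each map task processes a whole block, so if a block holds many more rows than `batch_size`, far more files are written and read than one batch needs. A back-of-the-envelope sizing helper (the 2–6 MB figures come from the numbers above; this is plain arithmetic, not a Ray API):

```python
def blocks_for_batches(num_files: int, batch_size: int) -> int:
    """Pick override_num_blocks so each block holds roughly one batch of rows."""
    return max(1, -(-num_files // batch_size))  # ceiling division

def batch_memory_mb(batch_size: int, min_mb: float = 2.0, max_mb: float = 6.0):
    """Estimated memory range for one batch of images, in MB."""
    return batch_size * min_mb, batch_size * max_mb

# For a 26-file toy dataset with batches of 4, ~7 blocks of ~4 rows each
# keeps each map task close to one batch's worth of I/O.
n_blocks = blocks_for_batches(26, 4)
# For the real workload (~32 images of ~2-6 MB), one batch is roughly
# 64-192 MB on top of any object-store copies Ray keeps.
mem_range = batch_memory_mb(32)
```

Such an estimate may help explain the crashes when increasing the batch size: the per-batch footprint grows linearly, and Ray’s object store holds intermediate copies as well.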
Seeking Guidance:
I am looking for insights or recommendations on optimizing this process, particularly regarding efficient data handling and improving GPU utilization without compromising system stability. Any advice on better integrating Ray’s capabilities with our project’s requirements would be greatly appreciated.
import os
import shutil

import numpy as np
import ray
import ray.data
import tifffile
import torch


@ray.remote
class Counter:
    def __init__(self):
        self.count = 0

    def increment(self):
        self.count += 1
        return self.count

    def get_count(self):
        return self.count


def save_image(row, path, counter):
    """Generates a random image and saves it, skipping files that already exist."""
    print(f"Saving image: {row['file']}")
    os.makedirs(path, exist_ok=True)
    file = row['file']
    im_path = os.path.join(path, file)
    print(im_path)
    if os.path.exists(im_path):
        print(f'skip write: {im_path}')
        row['file'] = im_path
        return row
    image_data = np.random.randint(0, 256, size=(100, 100), dtype=np.uint8)
    tifffile.imwrite(im_path, image_data)
    ray.get(counter.increment.remote())
    row['file'] = im_path
    return row


def read_image(row, counter):
    """Reads an image from disk."""
    print(f"Reading image from {row['file']}")
    row['image'] = tifffile.imread(row['file'])
    ray.get(counter.increment.remote())
    return row


def teardown(file_dir):
    """Empties the directory containing the images."""
    if not os.path.exists(file_dir):
        os.makedirs(file_dir, exist_ok=True)
    else:
        for filename in os.listdir(file_dir):
            file_path = os.path.join(file_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')


def collate_fn(batch):
    """Convert images to tensors."""
    print(f"Creating a batch: {[im for im in batch['file']]}")
    images_tensor = [torch.tensor(im.astype(float), dtype=torch.float32)
                     for im in batch['image']]
    return {'images': images_tensor}


FILE_PATH = '/tmp/alphabet'
MAX_EPOCHS = 5
BATCH_SIZE = 4
CONCURRENCY = 5
BLOCK_SIZE = 4

ray.init()
ctx = ray.data.DataContext.get_current()
ctx.execution_options.verbose_progress = True

upper_ascii_range = range(65, 91)  # 'A'..'Z'
files = [f"{chr(i)}.tiff" for i in upper_ascii_range]
print(f'the dataset contains: {files}')

ds = ray.data.from_items([{"file": file} for file in files],
                         override_num_blocks=BLOCK_SIZE)

write_counter = Counter.remote()
read_counter = Counter.remote()

ds = (ds
      .map(save_image, fn_kwargs={"path": FILE_PATH, "counter": write_counter}, concurrency=CONCURRENCY)
      .map(read_image, fn_kwargs={"counter": read_counter}, concurrency=CONCURRENCY)
      )

dataloader = ds.iter_torch_batches(batch_size=BATCH_SIZE,
                                   drop_last=True,
                                   collate_fn=collate_fn)

for epoch in range(MAX_EPOCHS):
    print(f"starting epoch: {epoch}")
    for batch in dataloader:
        num_writes = ray.get(write_counter.get_count.remote())
        num_reads = ray.get(read_counter.get_count.remote())
        print(f"there have been reads/writes: ({num_reads}/{num_writes}) so far")

downloaded_images = len(os.listdir(FILE_PATH))
total_writes = ray.get(write_counter.get_count.remote())
total_reads = ray.get(read_counter.get_count.remote())
print(f"Total writes to the directory: {total_writes}\n")
print(f"Total reads from the directory: {total_reads}\n")
print(f"There are {downloaded_images} images downloaded in {FILE_PATH}\n")
teardown(file_dir=FILE_PATH)