Challenges and Questions on Implementing Custom Sampling Strategy for CNN Training with Large Multichannel TIFF Files in Ray

Alright, last time someone pointed me towards this post:

which was very useful, so now I've given it another shot.

I've used a custom image dataloader, since I'm dealing with multichannel images, which the current dataloader does not support. I then created a random access dataset to store the images (because I'd like to access them multiple times to perform weighted sampling). Here is my code. Is this sensible?

import numpy as np
import ray 
import tifffile
from ray.data.datasource import ImageDatasource
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data.block import Block 
from typing import Iterator
import io
import pandas as pd

class MutitiffDatasource(ImageDatasource):
    """A modified image datasource for reading multipage tiff files."""
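    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        # Read the whole file into memory so tifffile can parse it from a buffer.
        data = f.readall()
        # Files without a TIFF extension are skipped and yield no block.
        if path.lower().endswith((".tif", ".tiff")):
            image = read_tiff(io.BytesIO(data))
            builder = DelegatingBlockBuilder()
            # Cast to float32 so every image block has the same dtype.
            array = np.array(image)
            item = {"image": array.astype(np.float32)}
            builder.add(item)
            block = builder.build()

            yield block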

def read_tiff(tiff_file: io.BytesIO) -> np.ndarray:
    """Helper function to read a multi-frame TIFF file into a NumPy array.

    The frames of the TIFF are stacked along the first axis (axis=0)
    of the returned array.

    Args:
        tiff_file: An in-memory binary buffer holding the contents of the
            multi-frame TIFF file (not a file path).

    Returns:
        A NumPy array in which each entry along axis 0 is one frame of the TIFF.
    """
    return tifffile.imread(tiff_file)



ds = ray.data.read_api.read_datasource(
    MutitiffDatasource(paths="local:///data", include_paths=True)
)
mds = ray.data.read_csv('local:///data/dataset.csv')

def create_imagepath(batch):
    """Renames the 'path' column to 'image_path' and keeps the image column."""
    return {'image_path': batch['path'], 'image': batch['image']}

def get_images(batch):
    """Returns images from the random access dataset."""
    batch["image"] = [record['image'] for record in rmap.multiget(batch['image_path'])]
    return batch

ds = ds.map_batches(create_imagepath, batch_format="pandas")
rmap = ds.to_random_access_dataset(key="image_path", num_workers=4)

mds = mds.map_batches(get_images, batch_format="pandas")
# Note: take_batch() returns only the first batch of rows, not the whole dataset.
full_ds = mds.take_batch()
print(full_ds)
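
As a rough sketch of the weighted sampling step this setup is meant to support (it is not part of the code above): the keys to look up can be drawn with replacement according to per-image weights, and then de-duplicated so that each image is fetched only once. The function name `sample_image_paths`, the example paths, and the weights are all hypothetical, and the sketch only uses `random.choices` from the Python standard library.

```python
import random
from collections import Counter


def sample_image_paths(paths, weights, k, seed=None):
    """Draw k paths with replacement, with probability proportional to weights."""
    if len(paths) != len(weights):
        raise ValueError("paths and weights must have the same length")
    if any(w < 0 for w in weights) or sum(weights) <= 0:
        raise ValueError("weights must be non-negative and sum to a positive value")
    rng = random.Random(seed)  # a seeded generator keeps the draw reproducible
    return rng.choices(paths, weights=weights, k=k)


# Hypothetical example data: three image paths with unnormalised weights.
paths = ["a.tif", "b.tif", "c.tif"]
weights = [0.0, 1.0, 3.0]

sampled = sample_image_paths(paths, weights, k=8, seed=0)

# Sampling with replacement produces repeated keys, so look each one up once
# and remember how often it was drawn.
unique_keys = sorted(set(sampled))
counts = Counter(sampled)
```

The de-duplicated `unique_keys` could then be passed to `rmap.multiget(...)` in place of the full column, with `counts` used to repeat each fetched image as often as it was drawn. Whether this is faster than fetching duplicates directly would need to be measured.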