Alright: last time, someone pointed me towards this post [link not preserved in this copy], which was very useful, so I've now given it another shot.
I've written a custom image datasource, since I'm dealing with multichannel images, which the built-in image reader doesn't support. I then created a random access dataset in which I store the images, because I'd like to access them multiple times to perform weighted sampling. Here is my code. Is this sensible?
import io
from typing import BinaryIO, Iterator, Union

import numpy as np
import pandas as pd
import ray
import tifffile
from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
from ray.data.block import Block
from ray.data.datasource import ImageDatasource
class MutitiffDatasource(ImageDatasource):
    """A modified image datasource for reading multi-page TIFF files.

    Overrides ``_read_stream`` so each file is decoded with ``read_tiff``
    (``tifffile``) and stored as a ``float32`` NumPy array under ``"image"``.
    """

    # Suffixes accepted by ``_read_stream``; compared case-insensitively.
    _TIFF_SUFFIXES = (".tif", ".tiff")

    def _read_stream(self, f: "pyarrow.NativeFile", path: str) -> Iterator[Block]:
        """Yield one single-row block holding the decoded TIFF at ``path``.

        Args:
            f: Open binary stream for the file; its full contents are read.
            path: Path of the file being read. It is used for the extension
                check and in the error message.

        Yields:
            One block whose row has an ``"image"`` entry containing the file's
            pixel data as a ``float32`` NumPy array.

        Raises:
            ValueError: If ``path`` does not end in ``.tif`` or ``.tiff``.
        """
        # Check the extension up front. Previously a non-TIFF file left
        # ``image`` unassigned and failed with an UnboundLocalError.
        if not path.lower().endswith(self._TIFF_SUFFIXES):
            raise ValueError(
                f"MutitiffDatasource can only read .tif/.tiff files, got {path!r}."
            )
        data = f.readall()
        image = read_tiff(io.BytesIO(data))
        builder = DelegatingBlockBuilder()
        builder.add({"image": np.asarray(image, dtype=np.float32)})
        yield builder.build()
def read_tiff(tiff_path: Union[str, BinaryIO]) -> np.ndarray:
    """Read a (possibly multi-page) TIFF into a NumPy array using ``tifffile``.

    Args:
        tiff_path: Any source ``tifffile.imread`` accepts: a filesystem path
            or a binary file-like object. Despite the parameter name,
            ``MutitiffDatasource`` passes an ``io.BytesIO`` wrapping the raw
            file bytes, not a path.

    Returns:
        The array produced by ``tifffile.imread``. For multi-page TIFFs the
        pages are typically stacked along the leading axis. The exact shape
        depends on how tifffile interprets the file's series, so verify it
        for your data.
    """
    return tifffile.imread(tiff_path)
# Read every TIFF under /data as its own row. ``include_paths=True`` supplies
# the ``path`` column that ``create_imagepath`` reads. ``file_extensions``
# limits the read to TIFFs: /data also contains dataset.csv (read below),
# which the TIFF datasource cannot decode.
ds = ray.data.read_datasource(
    MutitiffDatasource(
        paths="local:///data",
        include_paths=True,
        file_extensions=["tif", "tiff"],
    )
)
# Rows of this table refer to images through an ``image_path`` column, which
# ``get_images`` uses as the lookup key.
mds = ray.data.read_csv("local:///data/dataset.csv")
def create_imagepath(batch):
    """Expose the ``path`` column as ``image_path`` alongside ``image``.

    Args:
        batch: A batch with ``"path"`` and ``"image"`` columns. It is a pandas
            DataFrame when called via ``map_batches(..., batch_format="pandas")``.
            The input batch is not modified.

    Returns:
        A dict with exactly two columns, ``"image_path"`` and ``"image"``.
        Any other input columns are dropped.
    """
    return {"image_path": batch["path"], "image": batch["image"]}
def get_images(batch):
    """Returns images from the random access dataset.

    Looks up every ``batch["image_path"]`` value in the module-level random
    access dataset ``rmap``, which is keyed on ``image_path``. ``rmap`` must
    therefore exist before this function runs.

    Args:
        batch: A pandas DataFrame batch with an ``"image_path"`` column.
            NOTE(review): assumes dataset.csv stores paths in exactly the
            format Ray recorded for the TIFFs (e.g. with or without a scheme
            prefix). Confirm this, since lookups are exact-match.

    Returns:
        The same batch, with its ``"image"`` column set to the looked-up arrays.

    Raises:
        KeyError: If any path in the batch has no entry in ``rmap``.
    """
    paths = list(batch["image_path"])
    records = rmap.multiget(paths)
    # multiget yields None for keys it does not hold (per Ray's
    # RandomAccessDataset docs). Report the offending paths instead of
    # failing with "'NoneType' object is not subscriptable".
    missing = [path for path, record in zip(paths, records) if record is None]
    if missing:
        raise KeyError(
            f"{len(missing)} image path(s) not found in the random access "
            f"dataset, e.g. {missing[:3]}"
        )
    batch["image"] = [record["image"] for record in records]
    return batch
# Reshape the TIFF rows into (image_path, image) pairs.
ds = ds.map_batches(
    create_imagepath,
    batch_format="pandas",
)

# Index the images by image_path so get_images can fetch them by key.
rmap = ds.to_random_access_dataset(
    key="image_path",
    num_workers=4,
)

# Attach the stored image to each CSV row via its image_path.
mds = mds.map_batches(
    get_images,
    batch_format="pandas",
)

# Pull one batch to inspect the joined result.
full_ds = mds.take_batch()
print(full_ds)