How to do batch inference on a dataset?

Hi, I have a large dataset in webdataset format and I want to run distributed inference over it with Ray. I have working code for a single shard of the dataset, but I'm unsure how to extend it to the entire dataset, and I could not find any docs about this on the Ray website. Can you please point me in the right direction?
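
My best guess is that read_webdataset accepts a list of shard paths (or a directory) like the other Ray Data readers, so reading every shard would look something like the sketch below, but I haven't been able to confirm this:

import glob

import ray

# Hypothetical: read every shard in the directory instead of a single file.
# I'm assuming read_webdataset accepts a list of paths here.
shard_paths = sorted(glob.glob("/home/datasets/shards/shard_*.tar"))
ds = ray.data.read_webdataset(shard_paths)

Here is my working code for a single shard: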

import ray
from PIL import Image
import torch
import numpy as np
from typing import Any, Dict
import clip
from io import BytesIO

# A single shard for now; the full dataset has many shard_*.tar files.
dataset_path = "/home/datasets/shards/shard_000000.tar"
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/32", device=device)


def preprocess_image(row: Dict[str, Any]) -> Dict[str, Any]:
    original_image = Image.open(BytesIO(row["jpeg"]))
    # Ray Data columns must be numpy arrays, so convert both the raw
    # image and the CLIP-transformed tensor.
    return {
        "__key__": row["__key__"],
        "original_image": np.asarray(original_image),
        "transformed_image": transform(original_image).numpy(),
    }


ds = ray.data.read_webdataset(dataset_path)
transformed_ds = ds.map(preprocess_image)
# single_batch = transformed_ds.take_batch(10)
# print(single_batch.keys())
# print(single_batch["__key__"])
# print(single_batch["original_image"].shape)
# print(single_batch["transformed_image"].shape)


class CLIPModel:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model, _ = clip.load("ViT-B/32", device=self.device)
        self.model.eval()

        # Precompute the text features for the candidate labels once, so
        # each batch only runs the image encoder.
        self.labels = ["a diagram", "a dog", "a cat"]

        text_inputs = torch.cat(
            [clip.tokenize(label) for label in self.labels]
        ).to(self.device)

        with torch.inference_mode():
            self.text_features = self.model.encode_text(text_inputs)
            self.text_features /= self.text_features.norm(dim=-1, keepdim=True)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Convert the numpy batch of preprocessed images to a tensor and
        # move it to the GPU if one is available.
        images = torch.from_numpy(batch["transformed_image"]).to(self.device)
        with torch.inference_mode():
            image_features = self.model.encode_image(images)
            # Scaled cosine similarity against the precomputed text
            # features, then softmax over the candidate labels.
            prediction = (100.0 * image_features @ self.text_features.T).softmax(dim=-1)
            predicted_classes = prediction.argmax(dim=1).cpu()
            predicted_labels = [self.labels[i] for i in predicted_classes]
            return {
                "key": batch["__key__"],
                "predicted_label": np.array(predicted_labels),
                # "original_image": batch["original_image"],
            }


BATCH_SIZE = 100

predictions = transformed_ds.map_batches(
    CLIPModel,
    concurrency=3,  # Number of model replicas; match this to the GPUs in your cluster.
    num_gpus=1,  # One GPU per model replica.
    batch_size=BATCH_SIZE,  # Use the largest batch size that fits on the GPU.
)

# take_batch() only materializes a single batch of results; this is where
# I'm stuck for the full dataset.
prediction_batch = predictions.take_batch(BATCH_SIZE)

for index, (key, label) in enumerate(
    zip(prediction_batch["key"], prediction_batch["predicted_label"]), start=1
):
    print(f"[{index}] - Key: {key}, Label: {label}")
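
For the whole dataset, my guess is that I should either stream over all result batches or write them out, something like the sketch below. Both iter_batches and write_parquet are my assumptions from skimming the Ray Data API reference, and the output path is made up. Is this the intended pattern, or is there a better way?

# Hypothetical: stream every batch of predictions instead of just one.
for batch in predictions.iter_batches(batch_size=BATCH_SIZE):
    for key, label in zip(batch["key"], batch["predicted_label"]):
        print(f"Key: {key}, Label: {label}")

# Or persist all predictions to disk (the output directory is made up):
predictions.write_parquet("/home/datasets/predictions/")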