How to match inference results to dataset ids after map_batches

  • High: It blocks me from completing my task.

For the example "Image Classification Batch Inference with PyTorch" in the Ray 2.9.0 docs:

import numpy as np
import torch
from typing import Any, Dict
from torchvision import models, transforms
from torchvision.models import ResNet152_Weights

# Preprocessing transform bundled with the pretrained weights.
transform = transforms.Compose(
    [transforms.ToTensor(), ResNet152_Weights.IMAGENET1K_V1.transforms()]
)

class ResnetModel:
    def __init__(self):
        self.weights = ResNet152_Weights.IMAGENET1K_V1
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = models.resnet152(weights=self.weights).to(self.device)
        self.model.eval()  # Inference mode for BatchNorm/Dropout layers.

    def __call__(self, batch: Dict[str, np.ndarray]):
        # Convert the numpy array of images into a PyTorch tensor.
        # Move the tensor batch to GPU if available.
        torch_batch = torch.from_numpy(batch["transformed_image"]).to(self.device)
        with torch.inference_mode():
            prediction = self.model(torch_batch)
            predicted_classes = prediction.argmax(dim=1).detach().cpu()
            predicted_labels = [
                self.weights.meta["categories"][i] for i in predicted_classes
            ]
            return {
                "predicted_label": predicted_labels,
                "original_image": batch["original_image"],
            }

def preprocess_image(row: Dict[str, np.ndarray]):
    return {
        "original_image": row["image"],
        "transformed_image": transform(row["image"]),
    }

# `ds` is the image dataset loaded earlier in the example.
transformed_ds = ds.map(preprocess_image)
predictions = transformed_ds.map_batches(
    ResnetModel,
    concurrency=4,  # Use 4 GPUs. Change this number based on the number of GPUs in your cluster.
    num_gpus=1,  # Specify 1 GPU per model replica.
    batch_size=720,  # Use the largest batch size that can fit on our GPUs
)
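One way to keep each prediction tied to its row, sketched under assumptions: the output of map_batches() (and of map()) contains only the columns the function returns, so if preprocess_image and ResnetModel.__call__ both copy an `id` column through, every output row carries its own id. The sketch below swaps ResNet for a hypothetical `DummyModel` (argmax over the flattened image) so it runs without torch, Ray, or a GPU; only the column names `id` and `transformed_image` come from the example above.

```python
from typing import Dict

import numpy as np


class DummyModel:
    """Hypothetical stand-in for ResnetModel: same batch-in/batch-out shape."""

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        images = batch["transformed_image"]
        # Fake "inference": index of the largest value in each flattened image.
        predicted = images.reshape(len(images), -1).argmax(axis=1)
        return {
            # Copy the id column through unchanged, so each prediction keeps
            # the id of the row it was computed from (same order, same length).
            "id": batch["id"],
            "predicted_label": predicted,
        }


# A batch shaped like what map_batches passes in with the default batch format.
batch = {
    "id": np.array([10, 11, 12]),
    "transformed_image": np.array([np.eye(2), np.eye(2)[::-1] * 2, np.zeros((2, 2))]),
}
out = DummyModel()(batch)
print(out["id"], out["predicted_label"])
```

With the real model the change would be the same: add "id": row["id"] to the dict returned by preprocess_image and "id": batch["id"] to the dict returned by ResnetModel.__call__, then call map_batches as before.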

The predictions come back per batch.
How do I match each inference result to its dataset id?

for item in ds:

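Can you clarify what your goal is here? As you said, map_batches() processes the data in batches; its output is a new Dataset, which you can iterate over batch by batch with iter_batches(). If your batch function includes the id column in the dict it returns, an example use case would be: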

for batch in predictions.iter_batches():  # batch is a Dict[str, np.ndarray]
    ids = batch["id"]
    predicted_labels = batch["predicted_label"]
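If you need a lookup from id to prediction on the driver, a minimal sketch is to walk the output batch by batch and zip the two columns, since position i of every column in a batch belongs to the same row. Here `fake_batches` is a hypothetical stand-in for predictions.iter_batches() so the code runs without Ray, and it assumes the `id` column was carried through the batch function.

```python
from typing import Dict, Iterable

import numpy as np


def collect_predictions(batches: Iterable[Dict[str, np.ndarray]]) -> Dict[int, str]:
    """Build an id -> predicted_label mapping from batches of results."""
    id_to_label: Dict[int, str] = {}
    for batch in batches:
        # Within one batch, index i of every column refers to the same row.
        for row_id, label in zip(batch["id"].tolist(), batch["predicted_label"].tolist()):
            id_to_label[row_id] = label
    return id_to_label


# Hypothetical stand-in for predictions.iter_batches() with batches of size 2.
fake_batches = [
    {"id": np.array([3, 1]), "predicted_label": np.array(["cat", "dog"])},
    {"id": np.array([2]), "predicted_label": np.array(["fish"])},
]
result = collect_predictions(fake_batches)
print(result)
```

Collecting everything into one dict only works if the results fit in driver memory; for large outputs, writing the Dataset out (for example with write_parquet) keeps the id and label columns together row by row.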

map_batches() only supports pyarrow.Table, pandas.DataFrame, and numpy.ndarray batches. My data looks like this:
id int8
image float shape[100, 221, 7]
ext1 string
ext2 string[2]
I need ext1 and ext2 for postprocessing,
so I need to use a pyarrow.Table or pandas.DataFrame that includes ext1 and ext2,
but both of these types only support 1-D arrays, and the image is 3-D.

map_batches() also accepts the default batch type, Dict[str, np.ndarray], which should be able to hold the data you describe, no?
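To illustrate why the dict-of-arrays format fits this schema, here is a sketch of one batch as a batch function would see it: each column is a single array whose first axis is the row, so 3-D images become one 4-D array of shape (N, 100, 221, 7). It assumes `ext2` arrives as an (N, 2) string array; depending on how it is stored it could instead arrive as an object array. The `postprocess` function, the float32 dtype, and the `score`/`tag` output columns are hypothetical; the point is that `id`, `ext1`, and `ext2` stay aligned with the images by position.

```python
from typing import Dict

import numpy as np

N = 3  # rows in this example batch

# One batch in the default Dict[str, np.ndarray] format, using the schema above.
batch: Dict[str, np.ndarray] = {
    "id": np.arange(N, dtype=np.int8),
    "image": np.zeros((N, 100, 221, 7), dtype=np.float32),  # one 3-D image per row
    "ext1": np.array(["a", "b", "c"]),
    "ext2": np.array([["x0", "y0"], ["x1", "y1"], ["x2", "y2"]]),  # 2 strings per row
}


def postprocess(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Hypothetical batch function that uses ext1/ext2 next to the image."""
    # Stand-in for a model output: mean pixel value per image.
    scores = batch["image"].reshape(len(batch["image"]), -1).mean(axis=1)
    # Combine ext1 with the first ext2 entry of the same row.
    tags = np.char.add(batch["ext1"], batch["ext2"][:, 0])
    return {"id": batch["id"], "score": scores, "tag": tags}


out = postprocess(batch)
print(out["id"], out["score"], out["tag"])
```

Passed to map_batches with the default batch_format, a function like this would receive batches in this shape and could return any mix of multidimensional and string columns.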