[Serve] Batch inference

I followed the batch inference tutorial; the only difference is the value of batch_wait_timeout_s. Following the example, I use 9 Ray tasks to send requests to the Serve deployment (a sketch of my client code is below the deployment). However, it always prints "Our input array has length: 1" no matter the value of batch_wait_timeout_s (0, 1, or 10). Is this expected?

from typing import List

from starlette.requests import Request
from transformers import pipeline

from ray import serve

@serve.deployment
class BatchTextGenerator:
    def __init__(self, pipeline_key: str, model_key: str):
        self.model = pipeline(pipeline_key, model_key)

    @serve.batch(max_batch_size=4, batch_wait_timeout_s=10)
    async def handle_batch(self, inputs: List[str]) -> List[str]:
        print("Our input array has length:", len(inputs))

        results = self.model(inputs)
        return [result[0]["generated_text"] for result in results]

    async def __call__(self, request: Request) -> List[str]:
        return await self.handle_batch(request.query_params["text"])
    
generator = BatchTextGenerator.bind("text-generation", "gpt2")
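For reference, this is roughly how I send the requests with Ray tasks, following the client pattern from the tutorial. The address/port (the default local Serve HTTP endpoint) and the prompt texts are placeholders for my setup:

import ray
import requests

@ray.remote
def send_query(text: str) -> str:
    # Each Ray task issues one HTTP request to the Serve endpoint
    # (default port 8000; adjust the address for your cluster).
    resp = requests.get("http://localhost:8000/", params={"text": text})
    return resp.text

# Fire off 9 requests concurrently so the deployment has a chance to batch them.
texts = [f"Once upon a time {i}" for i in range(9)]
results = ray.get([send_query.remote(text) for text in texts])
for result in results:
    print(result)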

I ran the example on a RayCluster with the image rayproject/ray-ml:2.3.0 and KubeRay 0.4.0.

Hi, have you found a solution? I have the same problem and I don't know why.