I followed the batch inference tutorial; the only change I made is the value of batch_wait_timeout_s. Following the example, I use 9 Ray tasks to send requests to the Serve deployment (see the client sketch after the deployment code below). However, it always prints "Our input array has length: 1" no matter what value I set batch_wait_timeout_s to (0, 1, or 10). Is this expected?
```python
from typing import List

from starlette.requests import Request
from transformers import pipeline

from ray import serve


@serve.deployment
class BatchTextGenerator:
    def __init__(self, pipeline_key: str, model_key: str):
        # Load the Hugging Face pipeline (here, text-generation with gpt2).
        self.model = pipeline(pipeline_key, model_key)

    @serve.batch(max_batch_size=4, batch_wait_timeout_s=10)
    async def handle_batch(self, inputs: List[str]) -> List[str]:
        print("Our input array has length:", len(inputs))
        results = self.model(inputs)
        return [result[0]["generated_text"] for result in results]

    async def __call__(self, request: Request) -> List[str]:
        # Each request contributes one string; @serve.batch should collect
        # concurrent calls into a single list passed to handle_batch.
        return await self.handle_batch(request.query_params["text"])


generator = BatchTextGenerator.bind("text-generation", "gpt2")
```
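For reference, this is roughly how I send the requests, adapted from the tutorial's client snippet (the prompt strings and the localhost URL are placeholders for my setup):

```python
import ray
import requests

from ray import serve

# Deploy the generator defined above at the default route.
serve.run(generator)


@ray.remote
def send_query(text: str) -> str:
    # Each task sends one HTTP request carrying a single "text" param.
    resp = requests.get("http://localhost:8000/", params={"text": text})
    return resp.text


# Launch 9 tasks in parallel so their requests can overlap and be batched.
texts = [f"Once upon a time {i}" for i in range(9)]
results = ray.get([send_query.remote(text) for text in texts])
print(results)
```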
I ran the example on a RayCluster using the image rayproject/ray-ml:2.3.0 with KubeRay 0.4.0.