I am attempting to use Ray v2.23 for batch inference on multi-modal data with large multimodal models (LMMs). The relevant code is below; placeholders in angle brackets stand in for my actual values.
import base64
import io
import time
from typing import Dict

import numpy as np
import ray
from PIL import Image

dataset = ray.data.read_parquet("file_path")

class MyPredictor:
    def __init__(self):
        # Load the model once per actor so it is reused across batches.
        self.my_model = MyModel(model_path="<model_path>",
                                tensor_parallel_size=1)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        try:
            start_time = time.time()
            # Decode each base64-encoded image in the batch into a PIL image.
            inputs = [{"input": "<prompt>",
                       "data": {"image": Image.open(io.BytesIO(
                           base64.b64decode(batch["<image_column_name>"][i])))}}
                      for i in range(len(batch["<image_column_name>"]))]
            predictions = self.my_model.generate(
                inputs, sampling_params="<sampling_params>")
            batch["<output_label>"] = [pred.outputs[0].text for pred in predictions]
            end_time = time.time()
            print(f"Total inference time for {len(inputs)} items: {end_time - start_time:.2f}s")
        except OSError as os_error:
            print(f"OS error: {os_error}")
            batch["<output_label>"] = ["" for _ in range(len(batch["<image_column_name>"]))]
        except Exception as error:
            print(f"Misc error: {error}")
            batch["<output_label>"] = ["" for _ in range(len(batch["<image_column_name>"]))]
        finally:
            # Drop the raw image bytes so they are not carried into the output.
            del batch["<image_bytes_column>"]
        return batch
dataset = dataset.map_batches(
    MyPredictor,
    concurrency=int("<num_workers>") * int("<num_gpus>"),
    batch_size=int("<batch_size>"),
    num_gpus=1,
)
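Execution is lazy, so the map only runs once the dataset is consumed; something along these lines (with a placeholder output path) is roughly how the results get written out:

dataset.write_parquet("<output_path>")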
I have run into an issue: if an exception is raised while processing an item in a batch, the remaining items from that batch appear to accumulate into the next batch of the succeeding task. This causes subsequent tasks to fail because their batches overflow. Can anyone spot what I am overlooking?
I am looking for a way to skip only the problematic item and continue processing the rest of the batch. I have also considered skipping the entire batch when an exception occurs, but that does not resolve the overflow either.
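One direction I have been sketching (not verified as the right fix) is to move the error handling inside a per-item loop in __call__, so that a single bad image only blanks out its own row instead of raising for the whole batch. This reuses the same imports, placeholders, and self.my_model as above, and "<prompt>" is again a stand-in for the per-item text input:

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        outputs = []
        for image_bytes in batch["<image_column_name>"]:
            try:
                # Decode and run a single item; a failure only affects this row.
                image = Image.open(io.BytesIO(base64.b64decode(image_bytes)))
                prediction = self.my_model.generate(
                    [{"input": "<prompt>", "data": {"image": image}}],
                    sampling_params="<sampling_params>")
                outputs.append(prediction[0].outputs[0].text)
            except Exception as error:
                print(f"Skipping item due to error: {error}")
                outputs.append("")
        batch["<output_label>"] = outputs
        del batch["<image_bytes_column>"]
        return batch

This keeps the output batch the same length as the input batch, at the cost of calling generate per item; decoding each image individually and then generating once for the successfully decoded subset would be a middle ground, but I am not sure either addresses the accumulation behaviour I am seeing.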