I am really enjoying Ray Data, but I have run into an issue that has been surprisingly difficult for me to figure out.
I am trying to run inference on documents to perform entity extraction. The data comes from a Postgres table, which I read with something like:
ds = ray.data.read_sql("SELECT * FROM document_table;")
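(The call above is abbreviated; read_sql also takes a connection factory, so the real call looks something like the following, using psycopg2 here and with credentials elided:)

import psycopg2
import ray

def create_connection():
    # Connection factory that read_sql invokes to open a fresh connection.
    return psycopg2.connect(host="...", dbname="...", user="...", password="...")

ds = ray.data.read_sql("SELECT * FROM document_table;", create_connection)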
I then perform semantic chunking:
ds = ds.flat_map(process_document)
Here process_document takes in a row from document_table, splits it into chunks, and keeps the original document's metadata with each chunk. I use flat_map because if I start with m documents, I end up with n chunks (where n > m).
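For concreteness, process_document is shaped roughly like this (the semantic chunking itself is omitted; chunk_text and doc_id are stand-ins for my real chunker and metadata fields):

from typing import Any, Dict, List

def process_document(row: Dict[str, Any]) -> List[Dict[str, Any]]:
    # One document row fans out into one output row per chunk.
    chunks = chunk_text(row["text"])  # stand-in: yields (chunk, start_offset) pairs
    return [
        {
            "text": chunk,
            "start": start,  # chunk's character offset into the original document
            "doc_id": row["doc_id"],  # original document metadata carried along
        }
        for chunk, start in chunks
    ]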
So far, if I do ds.take_all(), this code works fine.
However, I have run into trouble adapting the Hugging Face pipeline inference tutorial to do named entity recognition instead of classification. Unlike classification, a single chunk of text can contain multiple (or zero) entities, so there isn't a one-to-one correspondence between inputs and outputs.
I want to use map_batches, on the understanding that it will be the most performant approach, and I've tried the following:
from typing import Dict

import numpy as np
import pandas as pd
from transformers import AutoTokenizer, pipeline


class NerModel:
    def __init__(self, model_length: int = 128, batch_size: int = 16, device: str = "mps"):
        # If doing CPU inference, set `device="cpu"` instead.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "my_model",
            model_max_length=model_length,
            truncation=True,
            max_length=model_length,
            device_map=device,
        )
        self.pipe = pipeline(
            "ner",
            model="my_model",
            tokenizer=self.tokenizer,
            aggregation_strategy="max",
            device=device,
        )
        self.batch_size = batch_size

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # Run NER over the whole batch of chunk texts at once.
        output = self.pipe(batch["text"], batch_size=self.batch_size)
        # Re-shape the columnar batch into one dict per row, so each chunk's
        # text and offset can be paired with its pipeline output.
        rows = pd.DataFrame(batch).to_dict(orient="records")
        outputs = []
        for _output, _row in zip(output, rows):
            # Each chunk can yield zero or more entities, hence the fan-out.
            for _o in _output:
                text = _row["text"]
                start = _row["start"]
                outputs.append(
                    {
                        "ent": text[_o["start"]:_o["end"]],
                        "score": float(_o["score"]),
                        # Translate entity offsets from chunk-relative back to
                        # document-relative using the chunk's start offset.
                        "start": start + _o["start"],
                        "end": start + _o["end"],
                    }
                )
        # Return in columnar (dict-of-lists) form, as map_batches expects.
        return pd.DataFrame(outputs).to_dict(orient="list")
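My reasoning for using a callable class here is that Ray runs it as an actor, so the tokenizer and pipeline are loaded once in __init__ and reused across batches rather than being reloaded for every batch.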
However, when I run this using:
ds = ds.map_batches(
    NerModel,
    concurrency=4,
    fn_constructor_kwargs={
        "model_length": 128,
        "device": "mps",
    },
    batch_format="numpy",
)
x = ds.take_all()
I get the following error:
Traceback (most recent call last):
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/exceptions.py", line 49, in handle_trace
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/plan.py", line 429, in execute_to_iterator
bundle_iter = itertools.chain([next(gen)], gen)
^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/interfaces/executor.py", line 37, in __next__
return self.get_next()
^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/legacy_compat.py", line 76, in get_next
bundle = self._base_iterator.get_next(output_split_idx)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor.py", line 153, in get_next
item = self._outer._output_node.get_output_blocking(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor_state.py", line 306, in get_output_blocking
raise self._exception
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor.py", line 232, in run
continue_sched = self._scheduling_loop_step(self._topology)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor.py", line 287, in _scheduling_loop_step
num_errored_blocks = process_completed_tasks(
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor_state.py", line 480, in process_completed_tasks
raise e from None
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor_state.py", line 447, in process_completed_tasks
bytes_read = task.on_data_ready(
^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/interfaces/physical_operator.py", line 105, in on_data_ready
raise ex from None
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/interfaces/physical_operator.py", line 101, in on_data_ready
ray.get(block_ref)
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/worker.py", line 2753, in get
values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/worker.py", line 904, in get_objects
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(UserCodeException): ray::FlatMap(process_document)->MapBatches(NerModel)() (pid=31529, ip=127.0.0.1, actor_id=abc42dcfe08876622841480901000000, repr=MapWorker(FlatMap(process_document)->MapBatches(NerModel)))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/util.py", line 78, in __call__
return future.result()
^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/concurrent/futures/_base.py", line 449, in result
return self.__get_result()
^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
raise self._exception
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/concurrent/futures/thread.py", line 58, in run
result = self.fn(*self.args, **self.kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/var/folders/jj/xjj4vr5d5fd_80p1r5md82rr0000gp/T/ipykernel_19773/492767468.py", line 33, in __call__
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/transformers/pipelines/token_classification.py", line 246, in __call__
_inputs, offset_mapping = self._args_parser(inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/transformers/pipelines/token_classification.py", line 42, in __call__
raise ValueError("At least one input is required.")
ValueError: At least one input is required.
Note that when I use this same code outside of Ray Data, all of my functions and NerModel work fine.
I'm pretty much at my wits' end and have no idea what I am doing wrong. If I modify NerModel to take in a single row (changing the internal logic to return a list of rows) and use flat_map instead, there are no problems.
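For completeness, that working row-at-a-time variant looks roughly like this (NerModelRow is just a name I'm using here; simplified):

from typing import Any, Dict, List

class NerModelRow(NerModel):
    # Single-row variant: same model setup as above, fed one row at a time.
    def __call__(self, row: Dict[str, Any]) -> List[Dict[str, Any]]:
        text = str(row["text"])
        entities = self.pipe(text)  # one plain string in, a list of entities out
        return [
            {
                "ent": text[e["start"]:e["end"]],
                "score": float(e["score"]),
                "start": row["start"] + e["start"],
                "end": row["start"] + e["end"],
            }
            for e in entities
        ]

ds = ds.flat_map(
    NerModelRow,
    concurrency=4,
    fn_constructor_kwargs={"model_length": 128, "device": "mps"},
)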
Please help me understand what I am doing wrong. Thank you!