`map_batches` fails with Hugging Face NER pipeline

I am really enjoying Ray Data, but I have run into an issue that I am finding surprisingly difficult to solve.

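I am trying to run inference on documents to extract entities. The data comes from a Postgres table, which I read with something like the following (create_connection is a placeholder for a zero-argument function that returns a DB-API connection):

ds = ray.data.read_sql("SELECT * FROM document_table;", create_connection)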

I then perform semantic chunking:

ds = ds.flat_map(process_document)

Here process_document takes a row from document_table, splits it into chunks, and keeps the original document's metadata with each chunk. I use flat_map because m documents produce n chunks, where n is usually larger than m.
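To make the shape of that step concrete, here is a minimal sketch of what such a function could look like. It is not my real code: fixed-size slicing stands in for the semantic chunking, and the `doc_id` column and the `chunk_size` parameter are made up for illustration. The `text` and `start` keys match the ones NerModel reads further down.

```python
from typing import Any, Dict, List


def process_document(row: Dict[str, Any], chunk_size: int = 200) -> List[Dict[str, Any]]:
    """Split one document row into chunk rows, copying the metadata onto each.

    Fixed-size slicing is a stand-in for the real semantic chunking.
    """
    text = row["text"]
    # Every column except the text itself is treated as metadata.
    metadata = {key: value for key, value in row.items() if key != "text"}
    chunks = []
    for start in range(0, len(text), chunk_size):
        chunks.append(
            {
                **metadata,
                "text": text[start : start + chunk_size],
                # Character offset of this chunk within the original document.
                "start": start,
            }
        )
    return chunks


# One input row becomes three output rows, which is why flat_map fits here.
rows = process_document({"doc_id": 7, "text": "abcdefghij"}, chunk_size=4)
print(rows)
```

Returning a list of dicts per input row is what lets flat_map turn one document into several chunk rows.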

Up to this point the code works fine when I call ds.take_all().

Next, I tried adapting the tutorial on Hugging Face pipeline inference so that it performs named entity recognition (NER) instead of classification.

Unlike classification, a single chunk of text can contain several entities or none at all, so there is no one-to-one correspondence between input chunks and output rows.

I want to use map_batches because I understand it to be the most performant option, and I have tried the following:

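from typing import Dict

import numpy as np
import pandas as pd
from transformers import AutoTokenizer, pipeline


class NerModel:
    def __init__(self, model_length: int = 128, batch_size: int = 16, device: str = "mps"):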
        # If doing CPU inference, set `device="cpu"` instead.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "my_model",
            model_max_length=model_length,
            truncation=True,
            max_length=model_length,
            device_map=device,
        )

        self.pipe = pipeline(
            "ner",
            model="my_model",
            tokenizer=self.tokenizer,
            aggregation_strategy="max",
            device=device,
        )

        self.batch_size = batch_size

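    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        # With batch_format="numpy", batch["text"] arrives as a NumPy array.
        output = self.pipe(batch['text'], batch_size=self.batch_size)

        # Convert the column-oriented batch into one dict per chunk.
        batch = pd.DataFrame(batch).to_dict(orient='records')

        # Emit one output row per detected entity, so a chunk may yield zero or many rows.
        outputs = []
        for _output, _batch in zip(output, batch):
            text = _batch['text']
            start = _batch['start']
            for _o in _output:
                row = {}
                row["ent"] = text[_o['start']:_o['end']]
                row["score"] = float(_o["score"])
                # Shift the entity offsets by the chunk's own start offset.
                row["start"] = start + _o["start"]
                row["end"] = start + _o["end"]
                outputs.append(row)

        # Convert the rows back into a column-oriented dict.
        outputs = pd.DataFrame(outputs).to_dict(orient='list')

        return outputs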
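A side note on the flattening step, separate from the error below: when a batch contains no entities at all, `pd.DataFrame([]).to_dict(orient='list')` gives back an empty dict with no columns. I have not checked how map_batches handles that, so the following is only a sketch of a flattening helper that always returns the same four columns. The name `flatten_entities` and the hand-written `fake_outputs` are made up for illustration; the real input would be whatever the pipeline returns.

```python
from typing import Any, Dict, List, Sequence

OUTPUT_COLUMNS = ("ent", "score", "start", "end")


def flatten_entities(
    texts: Sequence[str],
    chunk_starts: Sequence[int],
    ner_outputs: Sequence[List[Dict[str, Any]]],
) -> Dict[str, List[Any]]:
    """Flatten per-chunk entity lists into one column-oriented dict."""
    # Start from empty columns so the keys exist even when no entity is found.
    columns: Dict[str, List[Any]] = {name: [] for name in OUTPUT_COLUMNS}
    for text, chunk_start, entities in zip(texts, chunk_starts, ner_outputs):
        for entity in entities:
            columns["ent"].append(text[entity["start"] : entity["end"]])
            columns["score"].append(float(entity["score"]))
            # Shift the entity offsets by the chunk's own start offset.
            columns["start"].append(int(chunk_start) + entity["start"])
            columns["end"].append(int(chunk_start) + entity["end"])
    return columns


# Stand-in for pipeline output: one entity in the first chunk, none in the second.
fake_outputs = [[{"start": 0, "end": 5, "score": 0.99}], []]
result = flatten_entities(["Alice went home", "nothing here"], [100, 115], fake_outputs)
print(result)
```

The output batch has a different number of rows than the input batch, which is the one-to-many behaviour I am after.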

However, when I run this using:

ds = ds.map_batches(
    NerModel,
    concurrency=4,
    fn_constructor_kwargs={
        "model_length": 128,
        "device": "mps",
    },
    batch_format='numpy',
)

x = ds.take_all()

I get the following error:

Traceback (most recent call last):
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/exceptions.py", line 49, in handle_trace
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/plan.py", line 429, in execute_to_iterator
    bundle_iter = itertools.chain([next(gen)], gen)
                                   ^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/interfaces/executor.py", line 37, in __next__
    return self.get_next()
           ^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/legacy_compat.py", line 76, in get_next
    bundle = self._base_iterator.get_next(output_split_idx)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor.py", line 153, in get_next
    item = self._outer._output_node.get_output_blocking(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor_state.py", line 306, in get_output_blocking
    raise self._exception
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor.py", line 232, in run
    continue_sched = self._scheduling_loop_step(self._topology)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor.py", line 287, in _scheduling_loop_step
    num_errored_blocks = process_completed_tasks(
                         ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor_state.py", line 480, in process_completed_tasks
    raise e from None
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/streaming_executor_state.py", line 447, in process_completed_tasks
    bytes_read = task.on_data_ready(
                 ^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/interfaces/physical_operator.py", line 105, in on_data_ready
    raise ex from None
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/interfaces/physical_operator.py", line 101, in on_data_ready
    ray.get(block_ref)
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/worker.py", line 2753, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/_private/worker.py", line 904, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(UserCodeException): ray::FlatMap(process_document)->MapBatches(NerModel)() (pid=31529, ip=127.0.0.1, actor_id=abc42dcfe08876622841480901000000, repr=MapWorker(FlatMap(process_document)->MapBatches(NerModel)))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/ray/data/_internal/execution/util.py", line 78, in __call__
    return future.result()
           ^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/concurrent/futures/_base.py", line 449, in result
    return self.__get_result()
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/concurrent/futures/_base.py", line 401, in __get_result
    raise self._exception
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/concurrent/futures/thread.py", line 58, in run
    result = self.fn(*self.args, **self.kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/var/folders/jj/xjj4vr5d5fd_80p1r5md82rr0000gp/T/ipykernel_19773/492767468.py", line 33, in __call__
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/transformers/pipelines/token_classification.py", line 246, in __call__
    _inputs, offset_mapping = self._args_parser(inputs, **kwargs)
                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/xxxxx/miniconda3/envs/my_env/lib/python3.12/site-packages/transformers/pipelines/token_classification.py", line 42, in __call__
    raise ValueError("At least one input is required.")
ValueError: At least one input is required.
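One thing I have not ruled out is the input type. The traceback ends in the pipeline's argument parser, and with batch_format='numpy' the value of batch['text'] is a NumPy array rather than a list. Below is a small check, assuming (I have not verified this against the transformers source) that the parser only accepts a str or a non-empty list or tuple. The sample sentences are made up.

```python
import numpy as np

# Stand-in for batch["text"] as it arrives under batch_format="numpy".
text_column = np.array(["Alice lives in Paris.", "No entities here."])

# A NumPy array is neither a list nor a tuple, so a list/tuple check rejects it.
print(isinstance(text_column, (list, tuple)))  # False

# Converting gives a plain list of Python strings.
texts = text_column.tolist()
print(type(texts).__name__, type(texts[0]).__name__)  # list str
```

If that assumption holds, passing `batch['text'].tolist()` to self.pipe would be the first thing to try.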

Note that when I run the code above outside of Ray Data, all of my functions and NerModel work fine.

I'm pretty much at my wits' end and cannot work out what I am doing wrong. If I modify NerModel to take a single row (and change the internal logic to return a list) and use flat_map instead, there are no problems.
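For reference, the row-wise variant that works looks roughly like the sketch below. It is simplified: `extract_entities_from_row` is an illustrative name, and `fake_ner` is a made-up stand-in for the Hugging Face pipeline so the snippet runs without a model.

```python
from typing import Any, Callable, Dict, List


def extract_entities_from_row(
    row: Dict[str, Any],
    ner_fn: Callable[[str], List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
    """Return one output row per entity found in a single chunk row."""
    text = row["text"]
    return [
        {
            "ent": text[entity["start"] : entity["end"]],
            "score": float(entity["score"]),
            # Shift the entity offsets by the chunk's own start offset.
            "start": row["start"] + entity["start"],
            "end": row["start"] + entity["end"],
        }
        for entity in ner_fn(text)
    ]


def fake_ner(text: str) -> List[Dict[str, Any]]:
    # Hypothetical stand-in for the pipeline: tags only the word "Paris".
    index = text.find("Paris")
    return [] if index < 0 else [{"start": index, "end": index + 5, "score": 0.9}]


print(extract_entities_from_row({"text": "I love Paris", "start": 40}, fake_ner))
```

Here the model receives a single Python string per call, which is the main difference from the batched version.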

Please help me understand what I am doing wrong. Thank you!