How to use Ray Data to transform rows into batch objects?

I’m loving ray.data, with its distributed execution and lazy mapping for transformations; it’s great! But it’s early days and I’m having trouble with one step. I think I want almost the reverse of flat_map: a transform that takes in multiple rows and outputs a single row, which isn’t what map_batches did when I tried it.

Longer explanation:

The model we’re using operates on objects that represent a batch of data (distinct, in this case, from “a list or array of data objects”). So for evaluation I can use ray.data like this (mild pseudocode):

data = (
  ray.data.from_items(data_items)
  .map(transform1)
  .map(transform2)
)
# but now I have to step away and manually iterate over the batches
batch_objects = [create_batch_object(batch) for batch in data.iter_batches(batch_size=batch_size)]
# and call the model with batch objects
results = [model(bobj) for bobj in batch_objects]
# and then unpack each result and flatten into one result per row
row_results = [row for out in results for row in unpack_batch_object(out)]

I can’t find a Dataset function that aggregates multiple rows into a batch, only take_batch or iter_batches, which pull me out of the lazy flow. I would prefer something like

data = (
  ray.data.from_items(data_items)
  .map(transform1)
  .map(transform2)
  .batch_map(fn_taking_multiple_rows_and_returning_one_row)
  .map(model)
  .flat_map(unpack_batch_object)
)

where .batch_map “maps over batches”, taking a sequence of rows and returning one row.

Maybe this exists and despite blearily scanning the docs I just missed it; I thought map_batches would do this but it seems to map rows, just a batch at a time for better vectorization.

Does this exist?
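One possible route, sketched under assumptions: in recent Ray 2.x releases the function passed to map_batches may return a batch with a different number of rows than it received, so it can collapse each incoming batch into a single row, which is roughly the batch_map above. The BatchObject class, the text column, and the rows_to_one_row, run_model, unpack and build_dataset helpers are hypothetical stand-ins for create_batch_object, model and unpack_batch_object. Storing an arbitrary Python object in a column also assumes your Ray version supports object columns.

```python
from dataclasses import dataclass
from typing import Any, Dict, List

import numpy as np


@dataclass
class BatchObject:
    """Hypothetical stand-in for the model's own batch type."""
    texts: List[str]


def rows_to_one_row(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """map_batches UDF: collapse a whole incoming batch into a single output row."""
    obj = BatchObject(texts=[str(t) for t in batch["text"]])
    # One-element object array = one output row. Assumes your Ray version
    # can store arbitrary Python objects in a column.
    col = np.empty(1, dtype=object)
    col[0] = obj
    return {"batch_obj": col}


def run_model(row: Dict[str, Any]) -> Dict[str, Any]:
    """map UDF: toy 'model' producing one output (text length) per text in the batch."""
    obj = row["batch_obj"]
    return {"outputs": [len(t) for t in obj.texts]}


def unpack(row: Dict[str, Any]) -> List[Dict[str, Any]]:
    """flat_map UDF: turn one result row back into one row per original item."""
    return [{"output": int(o)} for o in row["outputs"]]


def build_dataset(texts: List[str], batch_size: int = 2):
    import ray  # imported here so the helpers above can be used without Ray

    return (
        ray.data.from_items([{"text": t} for t in texts])
        .map_batches(rows_to_one_row, batch_size=batch_size, batch_format="numpy")
        .map(run_model)
        .flat_map(unpack)
    )
```

Calling build_dataset(["a", "bb", "ccc"]).take_all() would run the whole chain lazily until materialized. One caveat: map_batches does not guarantee that every batch has exactly batch_size rows (batches can be smaller at block boundaries), so if batch composition matters, as with the text-length-based batching in the reply below, the groupby() approach gives more control.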

I’m in the same situation as you, including spending way too long reading the docs hoping to find this functionality. You can implement it using groupby() and map_groups().

For my task, I want a dynamic number of records per batch, based on text length. So I wrote a batcher and applied it with map_batches; it adds a unique batch_id to each record. Then I used groupby() to group the records by batch_id, and map_groups() to apply the next stage to each batch.

After the transform, I used another sort stage to restore the original record order.
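The steps described above could look roughly like the following sketch. It is not jlquinn's actual code: the text column, the row_id column used for the final sort, the MAX_CHARS character budget, the uuid-prefixed batch IDs and the toy model inside apply_model_to_group are all hypothetical.

```python
import uuid
from typing import Dict, List

import numpy as np

MAX_CHARS = 8  # hypothetical text budget per batch


def assign_batch_ids(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """map_batches UDF: greedily pack rows into batches of at most MAX_CHARS characters.

    A row longer than MAX_CHARS gets a batch of its own.
    """
    # Each map_batches call runs independently, so a random prefix keeps IDs unique.
    prefix = uuid.uuid4().hex
    ids: List[str] = []
    current, used = 0, 0
    for text in batch["text"]:
        n = len(text)
        if used and used + n > MAX_CHARS:  # start a new batch when the budget is exceeded
            current += 1
            used = 0
        used += n
        ids.append(f"{prefix}-{current}")
    return {**batch, "batch_id": np.array(ids)}


def apply_model_to_group(group: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """map_groups UDF: build one batch object per group, run the model, unpack per row."""
    batch_obj = {"texts": [str(t) for t in group["text"]]}  # stand-in for create_batch_object
    outputs = [len(t) for t in batch_obj["texts"]]  # stand-in for model(batch_obj) + unpacking
    return {"row_id": group["row_id"], "output": np.array(outputs)}


def build_dataset(texts: List[str]):
    import ray  # imported here so the helpers above can be used without Ray

    return (
        ray.data.from_items([{"row_id": i, "text": t} for i, t in enumerate(texts)])
        .map_batches(assign_batch_ids, batch_format="numpy")
        .groupby("batch_id")
        .map_groups(apply_model_to_group, batch_format="numpy")
        .sort("row_id")  # restore the original order after the shuffle
    )
```

Because groupby() shuffles data across the cluster, the output order is not preserved, hence the final sort on row_id; the shuffle also makes this pattern costlier than a single map_batches pass. In this sketch a batch never spans two map_batches calls, so the packing is only greedy within each incoming batch of rows.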


@jlquinn, would you be willing to contribute a docs article PR on this? I’ve seen this question from the community on multiple occasions.

If you ping me in #ray-contributors on the Ray Slack, I’ll help shepherd it through with the relevant committers (the folks with merge-to-master rights).

TPM @ Anyscale here.

@Sam_Chan Sure, I’d be open to doing that. I’ll get onto Slack soon.
