How to use Ray Data to transform rows into batch objects?

I’m loving ray.data and its distribution and lazy mapping for transformation; it’s great! But it’s early days for me and I’m having trouble with one step. I think I want almost the reverse of flat_map: a transform which takes in multiple rows and outputs a single row. That isn’t what map_batches did when I tried it.

Longer explanation:

The model we’re using operates on objects representing a batch of data (this is distinct from “a list or array of data objects” here). So I can use ray.data like this for evaluation (mild pseudocode):

import ray

data = (
  ray.data.from_items(data_items)
  .map(transform1)
  .map(transform2)
)
# but now I have to step out of the lazy flow and manually iterate over the batches
# (batch_size has to be passed by keyword to iter_batches)
batch_objects = [create_batch_object(batch) for batch in data.iter_batches(batch_size=batch_size)]
# then call the model on each batch object
results = [model(bobj) for bobj in batch_objects]
# and finally unpack each result and flatten everything into one list
result_batches = flatten([unpack_batch_object(res) for res in results])

I can’t find a Dataset method that lazily aggregates multiple rows into batches, just take_batch and iter_batches, both of which pull me out of the lazy flow. I would prefer something like:

data = (
  ray.data.from_items(data_items)
  .map(transform1)
  .map(transform2)
  .batch_map(fn_taking_multiple_rows_and_returning_one_row)
  .map(model)
  .flat_map(unpack_batch)
)

where .batch_map “maps over batches”, taking a sequence of rows and returning one row.
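To pin down the semantics I mean, here is a minimal plain-Python sketch with no Ray involved. The names batch_map and to_batch_row are made up for illustration (to_batch_row stands in for my create_batch_object), and I'm assuming the last batch may be smaller than batch_size:

```python
from itertools import islice
from typing import Callable, Iterable, Iterator, List, TypeVar

Row = TypeVar("Row")
Out = TypeVar("Out")


def batch_map(
    rows: Iterable[Row],
    fn: Callable[[List[Row]], Out],
    batch_size: int,
) -> Iterator[Out]:
    """Lazily group rows into lists of up to batch_size and yield fn(group)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    it = iter(rows)
    while True:
        # Pull at most batch_size rows; nothing is consumed until iteration.
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        # Many input rows go in, exactly one output row comes out.
        yield fn(chunk)


# Hypothetical stand-in for create_batch_object: one output row per batch.
def to_batch_row(chunk):
    return {"ids": [r["id"] for r in chunk], "size": len(chunk)}


rows = [{"id": i} for i in range(5)]
batch_rows = list(batch_map(rows, to_batch_row, batch_size=2))
# Five input rows with batch_size=2 give three output rows (sizes 2, 2, 1).
```

So the output has one row per batch rather than one row per input row, which is the behaviour I'd like to get from a Dataset without leaving the lazy pipeline.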

Maybe this exists and despite blearily scanning the docs I just missed it; I thought map_batches would do this but it seems to map rows, just a batch at a time for better vectorization.

Does this exist?