[Ray Data] Need a custom ray.data.aggregate.AggregateFn to sum over numpy arrays

Let’s say I have a map function get_conf_matrix that computes a confusion matrix for each row, and now I want to aggregate all the confusion matrices globally. The built-in sum only works on scalar values. How can I write a custom AggregateFn that sums over numpy arrays?

Basically, I need something like the (hypothetical) NpSum below so that the following two lines work.

conf_matrix_dataset = input.map(get_conf_matrix)
total_conf_matrix = conf_matrix_dataset.aggregate(ray.data.aggregate.NpSum("conf_matrix"))

As a workaround, I also tried using flat_map to produce multiple integer columns and then applying sum to each column, but I got the following error:

“Ray Data requires schemas for all datasets in Ray 2.5. This means that standalone Python objects are no longer supported. In addition, the default batch format is fixed to NumPy. To revert to legacy behavior temporarily, set the environment variable RAY_DATA_STRICT_MODE=0 on all cluster processes.”

# Assume that label_index, prediction_index, num_classes are all provided.

def get_conf_matrix(row: pd.Series):
    return [
        (f"conf_matrix[{i}, {j}]", 1 if i == label_index and j == prediction_index else 0)
        for i in range(num_classes)
        for j in range(num_classes)
    ]

flattened_conf_matrix_field_names = [f"conf_matrix[{i}, {j}]" for i in range(num_classes) for j in range(num_classes)]

return input.flat_map(get_conf_matrix).sum(flattened_conf_matrix_field_names)

Note: I am new to Ray Data, and I understand that map_batches is recommended here. I just want to start simple and implement a map/flat_map based version first. Thanks!

Hey @wayi, what happened when you tried implementing a custom AggregateFn?
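
For reference, here is a rough sketch of what a custom aggregation for this could look like. It assumes each row is a dict holding a NumPy array of known shape under "conf_matrix". The helper names make_np_sum_callbacks and make_np_sum are made up for illustration, and the AggregateFn wiring has not been checked against any particular Ray release, so treat it as a starting point only. The simulation at the bottom runs without Ray and just mimics summing two partitions.

```python
from functools import reduce

import numpy as np


def make_np_sum_callbacks(column, shape):
    """Build init/accumulate/merge callbacks that sum arrays element-wise.

    `column` and `shape` are assumptions about the dataset layout.
    """

    def init(_key):
        # Fresh zero accumulator for each group (or for the whole dataset).
        return np.zeros(shape, dtype=np.int64)

    def accumulate_row(acc, row):
        # Add one row's matrix into the running total.
        return acc + np.asarray(row[column])

    def merge(left, right):
        # Combine partial totals computed on different blocks.
        return left + right

    return init, accumulate_row, merge


def make_np_sum(column, shape):
    """Hypothetical wiring into Ray; imported lazily so the rest runs without Ray."""
    from ray.data.aggregate import AggregateFn

    init, accumulate_row, merge = make_np_sum_callbacks(column, shape)
    return AggregateFn(
        init=init,
        accumulate_row=accumulate_row,
        merge=merge,
        name=f"np_sum({column})",
    )


# Local simulation of the aggregation over two partitions (no Ray needed).
rows = [
    {"conf_matrix": np.eye(2, dtype=np.int64)},
    {"conf_matrix": np.array([[0, 1], [0, 0]])},
    {"conf_matrix": np.eye(2, dtype=np.int64)},
]
init, accumulate_row, merge = make_np_sum_callbacks("conf_matrix", (2, 2))
part_a = reduce(accumulate_row, rows[:2], init(None))
part_b = reduce(accumulate_row, rows[2:], init(None))
total = merge(part_a, part_b)
```

If the wiring works on your Ray version, the call from the question would become something like conf_matrix_dataset.aggregate(make_np_sum("conf_matrix", (num_classes, num_classes))).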

This error just means you need to return a list[dict[str, Any]] from flat_map (each row must be a dict[str, Any]).
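
Something along these lines should satisfy that requirement. It is only a sketch: the values of label_index, prediction_index and num_classes are placeholders, and the function is called directly on a plain dict here so it can be checked without a Ray cluster.

```python
# Placeholder values; in the original post these are "all provided".
label_index = 0
prediction_index = 1
num_classes = 2


def get_conf_matrix(row):
    # Return a list with one dict per output row. Each dict maps a column
    # name to a value, which is the row shape flat_map expects.
    return [
        {
            f"conf_matrix[{i}, {j}]": int(i == label_index and j == prediction_index)
            for i in range(num_classes)
            for j in range(num_classes)
        }
    ]


flattened_conf_matrix_field_names = [
    f"conf_matrix[{i}, {j}]"
    for i in range(num_classes)
    for j in range(num_classes)
]

# Direct call on a stand-in row to show the output shape.
example_rows = get_conf_matrix({"any_column": 0})
```

With rows shaped like this, the flat_map(get_conf_matrix).sum(flattened_conf_matrix_field_names) call from the question should no longer hit the schema error. Since each input row produces exactly one output row, plain map returning the dict itself (without the enclosing list) would likely work too.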

Thanks for the explanation! Interesting to know that the error message refers to this common error.

The question was asked quite a while ago. I’ve got a workaround without using the same operators.