Ray Data map_batches performance optimization

How severe does this issue affect your experience of using Ray?

  • Medium: It causes significant difficulty in completing my task, but I can work around it.

We are building a feature engineering pipeline using Ray Data for transaction data with approximately 4.2 million rows (the actual production data will be in the billions). The pipeline involves creating custom functions for distance calculation, window-based aggregation features, and other steps.

The pipeline combines map_batches and map_groups operations over customer IDs. Below is an overview of the steps:

def make_features(rayDF):
    decimal_cols = get_decimal_cols(rayDF)
    rayDF1 = rayDF.map_batches(
        transform_dtype,
        batch_format="pandas",
        fn_kwargs={"decimal_cols": decimal_cols},
    )
    # Bucket customers into clusters so map_groups runs over a bounded number of groups
    rayDF1 = rayDF1.add_column("user_cluster", lambda df: assign_clusters(df["CUSTOMER_XID_ENC"]))
    # Impute missing data via map_groups
    rayDF1 = drop_and_impute(rayDF1)
    print("distance calculation start")
    rayDF1 = rayDF1.map_batches(get_distance_metrics, batch_format="pandas")
    print("distance calculation end")
    # Window aggregations of transaction amounts via groupby("user_cluster").map_groups
    print("get_window_features start")
    rayDF1 = get_window_features(rayDF1, window_list=[1, 7, 30, 35])
    print("get_window_features end")
    rayDF1 = fix_datetimes(rayDF1)
    print("fix_datetimes end")
    return rayDF1


rayDF1 = make_features(rayDF)

Calling this function alone takes almost 2-3 minutes before any actual computation is performed. Do you have any suggestions on how to reduce this time, considering that we are not running any materializing operation here? Or am I missing something?

After that, performing the write actions (rayDF1.write_parquet(output_file_path)) takes another 3 minutes.

Can you share the drop_and_impute, get_window_features, and fix_datetimes functions? The code section shared doesn’t look like it should be triggering execution.

You are right, there is no actual execution triggered there. Below are the snippets for the different functions:

Here we are imputing the missing data:

import numpy as np

def fill_first_valid_index(group):
    # Normalize empty strings to a sentinel / missing value.
    group['MER_ID'] = group['MER_ID'].replace('', 'Unknown')
    group['MER_ZIP'] = group['MER_ZIP'].replace('', None)
    group['CUS_ZIP'] = group['CUS_ZIP'].replace('', None)

    # Broadcast each group's first value to all rows of the group.
    group["CUS_ZIP"] = group.groupby("CUSTOMER_XID_ENC")["CUS_ZIP"].transform(lambda x: x.iloc[0])
    group["MER_ID"] = group.groupby("MER_ID")["MER_ID"].transform(lambda x: x.iloc[0])

    group['CUS_ZIP_f'] = np.where(group['CUS_ZIP'].notna(), group['CUS_ZIP'].str[:5], None)
    group['MER_ZIP_f'] = np.where(group['MER_ZIP'].notna(), group['MER_ZIP'].str[:5], None)

    group['CUS_ZIP_3'] = group['CUS_ZIP_3'].fillna("")
    group['CUS_STATE'] = group['CUS_STATE'].fillna("")

    return group

def drop_and_impute(rayDF):
    rayDF = rayDF.drop_columns(['CUSTOMER_XID','CUSTOMER_XID_HASH','ACCT_NBR_HASH'])
    rayDF = rayDF.groupby("user_cluster").map_groups(fill_first_valid_index, batch_format="pandas")
    return rayDF
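For clarity, the `transform(lambda x: x.iloc[0])` pattern above broadcasts the first row's value across each group. A minimal pandas illustration with toy (hypothetical) data:

```python
import pandas as pd

# Toy frame: two customers, CUS_ZIP present only on each customer's first row.
df = pd.DataFrame({
    "CUSTOMER_XID_ENC": ["a", "a", "b", "b"],
    "CUS_ZIP": ["10001-123", None, "94105-999", None],
})

# Broadcast each customer's first CUS_ZIP value to all of that customer's rows.
# Caveat: iloc[0] takes the first row even if it is null; if "first non-null"
# is intended (as the function name suggests), something like
# x.dropna().iloc[0], guarded for all-null groups, would be closer.
df["CUS_ZIP"] = df.groupby("CUSTOMER_XID_ENC")["CUS_ZIP"].transform(lambda x: x.iloc[0])
print(df["CUS_ZIP"].tolist())
# → ['10001-123', '10001-123', '94105-999', '94105-999']
```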

def get_window_features(rayDF, window_list=[1, 7, 30, 35]):

    rayDF = rayDF.add_column("Fraud_TRN_AMT", lambda df: df['TRN_AMT'] * df['IS_FRAUD'])

    agg_params_list = [
        {
            'new_col_prefix': "Trn_Amt_",
            'window_list': window_list,
            'group_by': "CUSTOMER_XID_ENC",
            'agg_col': "TRN_AMT",
            'agg_func': "sum",
            'order_by_col': "TRN_DT",
        },
        {
            'new_col_prefix': 'Mer_FRD_Trn_Amt_',
            'window_list': window_list,
            'group_by': "MER_ID",
            'agg_col': "Fraud_TRN_AMT",
            'agg_func': "sum",
            'order_by_col': "TRN_DT",
        },
    ]

    batch_format = "pandas"
    # window_agg does the actual window aggregations on either a polars or pandas DataFrame
    rayDF = rayDF.groupby("user_cluster").map_groups(
        window_agg,
        fn_kwargs={
            "fn_kwargs_list": agg_params_list,
            "backend": "polars",
            "batch_format": batch_format,
        },
        batch_format=batch_format,
    )
    return rayDF