Dear Ray community,
I am struggling to optimize my code. It is composed of:

- A `FeatureStore`: a Python class that handles requests of K numpy ids and returns a PyArrow Table with K rows and D columns with mixed data types.
- N `FeatureStorePartition` actors, spread over the whole cluster, that hold the data inside a PyArrow Table with an ID column and D other columns with mixed data types.
```python
import ray
import pyarrow as pa
import pyarrow.compute as pc

node_id_col = "id"  # name of the ID column

@ray.remote(
    scheduling_strategy="SPREAD",
    runtime_env={"pip": ["pyarrow>=11.0.0"]},
)
class FeatureStorePartition:
    features: pa.Table

    def __init__(self, data: pa.Table):
        self.features = data

    def get_features(self, indices: pa.Array) -> pa.Table:
        # Keep only the rows whose ID is among the requested indices.
        mask = pc.is_in(self.features[node_id_col], value_set=indices)
        return self.features.filter(mask).combine_chunks()
```
(code simplified for this post)
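For context, the partitions are created roughly like this, continuing from the class above (a minimal sketch: the shard sizes, column names, and the way the toy data is generated are illustrative, not my real loading code):

```python
import numpy as np
import pyarrow as pa

# Toy shards for illustration; my real shards are ~500K rows x 1024 columns.
N, rows, d = 5, 1_000, 4
shards = [
    pa.table({
        node_id_col: np.arange(i * rows, (i + 1) * rows),
        **{f"f{j}": np.random.rand(rows) for j in range(d)},
    })
    for i in range(N)
]
actors = [FeatureStorePartition.remote(shard) for shard in shards]
```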
Every time the `FeatureStore` receives a request, it calls the actors in parallel so that each one checks whether it holds the requested ids and returns the matching rows:

```python
futures = [actor.get_features.remote(indexes_for_actor[i]) for i, actor in enumerate(actors)]
```

This makes N remote calls, each returning a PyArrow table to the `FeatureStore`, which concatenates them:

```python
features = pa.concat_tables(ray.get(futures))
```
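Putting it together, the request path inside the `FeatureStore` looks roughly like this (a sketch: the modulo routing of ids to actors is an assumption for illustration only; the real routing depends on how the data was sharded):

```python
import numpy as np
import pyarrow as pa
import ray

class FeatureStore:
    def __init__(self, actors):
        self.actors = actors

    def get_features(self, ids: np.ndarray) -> pa.Table:
        # Split the requested ids per actor (modulo routing is illustrative,
        # not my real partitioning scheme).
        indexes_for_actor = [
            pa.array(ids[ids % len(self.actors) == i])
            for i in range(len(self.actors))
        ]
        futures = [
            actor.get_features.remote(indexes_for_actor[i])
            for i, actor in enumerate(self.actors)
        ]
        return pa.concat_tables(ray.get(futures))
```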
Now, I observed that the `ray.get` call above is very slow: it takes ~0.42 seconds for requests with D = 1024 and K ≈ 2K on a cluster with 5 workers. Each `FeatureStorePartition` actor holds ~500K rows. I also measured that `get_features` itself takes only 0.03 seconds on average, so I assume the network calls are the bottleneck. Is there a way to optimize the code above? Or is there a way to build something like what I made, which resembles a database, with very low latency?
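For reference, the two numbers come from timing like this (a sketch; the in-actor number assumes the same pattern wrapped around the filter inside `get_features`):

```python
import time

t0 = time.perf_counter()
features = pa.concat_tables(ray.get(futures))
print(f"end-to-end: {time.perf_counter() - t0:.3f} s")  # ~0.42 s for me

# The same timing around the filter inside get_features reports ~0.03 s,
# which is why I suspect serialization/network overhead dominates.
```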
thanks