We’re currently using Ray for data transformation and retrieval. We’re very excited about the potential of Ray Datasets; however, the performance we’re seeing (particularly when retrieving data) is poor compared to other methods.
My question: is async wait/get functionality expected to be implemented in Ray Datasets to improve performance? Is there anything I should be doing to speed up data retrieval when using Ray Datasets?
I’ve included a script that compares the performance of five different ways of getting data back from Ray. Each request reads 90 partitions (each a 5 MB Parquet file) and returns the data as a list of PyArrow tables. The five methods are:
- ray dataset
- ray.get (synchronous actor)
- ray.get (asynchronous actor)
- ray wait/get (synchronous actor)
- ray wait/get (asynchronous actor)
Run times are:
Script:
import ray
import pandas as pd
import pyarrow.parquet as pq
import time
import asyncio
@ray.remote
class SyncParquetReadActor:
    """Ray actor with a blocking ``read`` method that loads one parquet file.

    ``read(path)`` returns whatever ``pq.read_table(path)`` produces.
    """

    def __init__(self) -> None:
        """Create the actor; it takes no arguments and keeps no state."""

    def read(self, path):
        """Load the parquet file at ``path`` and return the resulting table."""
        table = pq.read_table(path)
        return table
@ray.remote
class AsyncParquetReadActor:
    """Ray actor with coroutine methods that load one parquet file per call.

    Ray treats a class with ``async def`` methods as an async actor, so
    several ``read`` calls may be in flight on one actor at once -- but only
    if the coroutines actually yield to the event loop. ``pq.read_table`` is
    a blocking call, so it is pushed onto the event loop's default
    thread-pool executor rather than being run on the loop itself.
    """

    async def _read(self, path):
        """Run ``pq.read_table(path)`` on a worker thread and return its result."""
        loop = asyncio.get_running_loop()
        # Calling pq.read_table directly here would block the event loop for
        # the whole read, defeating the point of an async method.
        return await loop.run_in_executor(None, pq.read_table, path)

    async def read(self, path):
        """Return the table read from the parquet file at ``path``."""
        return await self._read(path)

    @staticmethod
    async def wait_get(results):
        """Await ``results`` incrementally and collect their values as they finish.

        Called directly on the class (not via ``.remote``), this runs in the
        caller's event loop, not on an actor.

        Args:
            results: iterable of awaitables (e.g. Ray ``ObjectRef``s).

        Returns:
            A list with one inner list per wait round, each holding the
            values of the awaitables that completed in that round. Order
            within and across rounds follows completion, not input order.
            An empty ``results`` yields ``[]``.

        Raises:
            Whatever exception the first failed awaitable in a round raised.
        """
        # asyncio.wait is documented to take Futures/Tasks, so wrap each
        # awaitable explicitly instead of relying on implicit wrapping.
        pending = {asyncio.ensure_future(ref) for ref in results}
        output = []
        while pending:
            # FIRST_COMPLETED makes this a real wait/get loop; the default
            # (ALL_COMPLETED) would wait for everything in a single round.
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            output.append([task.result() for task in done])
        return output
def get_dataset(paths):
    """Read ``paths`` with Ray Data and return the data as a list of Arrow tables.

    In the Ray releases that provide ``Dataset.to_arrow()``, that method
    returns a list of ``ObjectRef``s to ``pyarrow.Table`` blocks, not the
    tables themselves; newer releases call it ``Dataset.to_arrow_refs()``
    (confirm against the installed Ray version). The refs are fetched with
    ``ray.get``. Without this step the caller only receives references, and
    any timing around this call would miss the data transfer that the
    actor-based ``ray.get`` variants pay for.

    Args:
        paths: list of parquet file paths to read.

    Returns:
        A list of ``pyarrow.Table`` objects, one per dataset block.
    """
    dataset = ray.data.read_parquet(paths)
    # Prefer the current method name; fall back to the older one.
    to_refs = getattr(dataset, 'to_arrow_refs', None) or dataset.to_arrow
    return ray.get(to_refs())
def _timed(label, fn):
    """Call ``fn()``, print ``'<label> run time: <seconds>'`` and return its result.

    ``time.perf_counter`` is used instead of ``time.time`` because it is
    monotonic and high-resolution, so a wall-clock adjustment during a run
    cannot skew the measurement.
    """
    start = time.perf_counter()
    result = fn()
    print(f'{label} run time: {time.perf_counter() - start}')
    return result


def _sync_wait_get(refs):
    """Collect ``refs`` as they become ready using a ``ray.wait``/``ray.get`` loop.

    Returns one inner list per ``ray.wait`` round, each holding the values
    ``ray.get`` returned for the refs reported ready in that round.
    """
    output = []
    pending = refs
    while pending:
        done, pending = ray.wait(pending)
        output.append(ray.get(done))
    return output


def main():
    """Time five strategies for reading the daily parquet partitions through Ray.

    Builds one path per day from 2021-01-01 to 2021-03-31 (90 paths) and
    prints how long each strategy takes to read all of them.
    """
    ray.init('ray://<RAY ADDRESS>:10001')
    base_path = '/<DATA PATH>/CloseDate={}/aam.parquet'
    dates = pd.date_range('2021-01-01', '2021-03-31', freq='D').date
    paths = [base_path.format(dt.strftime('%Y-%m-%d')) for dt in dates]

    _timed('Ray dataset', lambda: get_dataset(paths))

    # NOTE(review): every actor-based strategy below starts a brand-new actor
    # per path *inside* the timed region, so each timing includes creating
    # 90 actors, not just reading 90 files. For a fairer comparison, create
    # the actors (or a small pool) before starting the clock, or use plain
    # @ray.remote tasks.
    _timed('Sync. ray.get', lambda: ray.get(
        [SyncParquetReadActor.remote().read.remote(path) for path in paths]
    ))

    _timed('Async. ray.get', lambda: ray.get(
        [AsyncParquetReadActor.remote().read.remote(path) for path in paths]
    ))

    _timed('Sync. ray wait/get', lambda: _sync_wait_get(
        [SyncParquetReadActor.remote().read.remote(path) for path in paths]
    ))

    _timed('Async. ray wait/get', lambda: asyncio.run(
        AsyncParquetReadActor.wait_get(
            [AsyncParquetReadActor.remote().read.remote(path) for path in paths]
        )
    ))


if __name__ == '__main__':
    main()