Ray Dataset with Distributed PyTorch

Hi,

I am working on a distributed PyTorch pipeline with Ray Dataset and have a question:

After running ds = ray.data.read_parquet("path", parallelism=20) and then calling ds.to_torch(), should I think of the resulting PyTorch dataset as a distributed dataset, or is it just an ordinary one like torchvision.datasets.CIFAR10("dir")?

Thanks!

Hi @sgwhat, thanks for posting! I just noticed that this post never got a response.

Yes. If you run ds = ray.data.read_parquet() on a multi-node cluster, the read tasks (and therefore your data) will be spread across the cluster. All of the data underlying the returned Torch Dataset is therefore distributed, and it is pulled to the consumer (the trainer) as you iterate over the dataset during training.
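To make the "distributed under the hood, pulled on iteration" idea concrete, here is a toy, pure-Python model. It is not Ray's API and does not use Ray at all. ToyCluster, toy_read and iter_rows are hypothetical names made up for illustration. The sketch assumes that a read with parallelism=20 yields roughly 20 blocks placed on different nodes, and that the consumer holds only references to those blocks and fetches each one when iteration reaches it. A local dataset such as CIFAR10 behaves differently because all of its samples already sit in the local process.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterator, List, Tuple


@dataclass
class ToyCluster:
    """Hypothetical stand-in for per-node object stores (not a Ray class)."""
    # node name -> {block id -> rows stored on that node}
    stores: Dict[str, Dict[int, List[int]]] = field(default_factory=dict)
    # records every block transfer to the consumer, in order
    fetch_log: List[str] = field(default_factory=list)

    def put(self, node: str, block_id: int, rows: List[int]) -> None:
        self.stores.setdefault(node, {})[block_id] = rows

    def fetch(self, node: str, block_id: int) -> List[int]:
        # Log the transfer so we can see that blocks move only on demand.
        self.fetch_log.append(f"{node}:{block_id}")
        return self.stores[node][block_id]


def toy_read(cluster: ToyCluster, nodes: List[str], num_rows: int,
             parallelism: int) -> List[Tuple[str, int]]:
    """Split rows into `parallelism` blocks, place them round-robin on nodes,
    and return (node, block_id) references rather than the data itself."""
    refs = []
    per_block = -(-num_rows // parallelism)  # ceiling division
    for block_id in range(parallelism):
        start = block_id * per_block
        rows = list(range(start, min(start + per_block, num_rows)))
        node = nodes[block_id % len(nodes)]
        cluster.put(node, block_id, rows)
        refs.append((node, block_id))
    return refs


def iter_rows(cluster: ToyCluster,
              refs: List[Tuple[str, int]]) -> Iterator[int]:
    """Consumer (trainer) side: fetch one block at a time, lazily."""
    for node, block_id in refs:
        yield from cluster.fetch(node, block_id)


if __name__ == "__main__":
    cluster = ToyCluster()
    refs = toy_read(cluster, ["node-a", "node-b"], num_rows=100, parallelism=20)
    it = iter_rows(cluster, refs)
    first = [next(it) for _ in range(5)]
    print(first)              # [0, 1, 2, 3, 4]
    print(cluster.fetch_log)  # ['node-a:0']  (only the first block has moved)
```

The generator is the key design choice here: until the trainer asks for the next row, the remaining 19 blocks stay on their nodes, which mirrors the behavior described above without claiming anything about Ray's internal transfer mechanism.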