Ray Data: How to yield entire groups from a batch?

I want to use Ray Data for a sequential data problem. When I take one batch of data, I want it to contain all the rows associated with one group. So if I take 32 batches, I get all the data for 32 different items that I grouped by.

Here is a rough example:

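    import pandas as pd
    import ray

    def to_group(group: pd.DataFrame) -> pd.DataFrame:
        # Identity UDF: return each group unchanged.
        return group

    ds = (
        ray.data.read_parquet("s3://anonymous@ray-example-data/iris.parquet")
        .groupby("variety")
        # batch_format="pandas" so that to_group receives a DataFrame, matching its type hint
        .map_groups(to_group, batch_format="pandas")
    )

    # Desired: 1 batch = all Virginica rows, the next batch = all Setosa rows, etc.
    # Actually, take(1) returns a list with one row and take_batch(1) a batch of one row.
    x = ds.take(1)
    x = ds.take_batch(1)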
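A minimal sketch of one possible workaround, assuming it is acceptable to do the regrouping in a single process (for example on the driver) rather than in parallel: sort the dataset by the group key so each group's rows are contiguous, then regroup the row stream with Python's `itertools.groupby`. The helper names `iter_groups` and `iter_group_batches` are hypothetical, not part of Ray's API.

```python
import itertools
from typing import Any, Dict, Iterable, Iterator, List

Row = Dict[str, Any]


def iter_groups(rows: Iterable[Row], key: str) -> Iterator[List[Row]]:
    """Yield each run of consecutive rows that share the same `key` value as one list.

    Assumes `rows` is already sorted by `key`, so every group is contiguous.
    """
    for _, group in itertools.groupby(rows, key=lambda row: row[key]):
        # Materialize the group now: groupby's sub-iterators are invalidated
        # as soon as the outer iterator advances.
        yield list(group)


def iter_group_batches(
    rows: Iterable[Row], key: str, groups_per_batch: int
) -> Iterator[List[List[Row]]]:
    """Yield lists of up to `groups_per_batch` complete groups."""
    if groups_per_batch < 1:
        raise ValueError("groups_per_batch must be >= 1")
    groups = iter_groups(rows, key)
    while True:
        batch = list(itertools.islice(groups, groups_per_batch))
        if not batch:
            return
        yield batch
```

With Ray, the row stream could come from something like `ds.sort("variety").iter_rows()`, which yields one dict per row; passing it to `iter_group_batches(rows, "variety", 32)` would then yield 32 complete groups per step. A group that spans several blocks is still returned whole, because the regrouping happens on the streamed rows rather than on Ray's blocks.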

While looking for alternatives, particularly in PyTorch, I noticed that the documentation below seems misleading: it suggests that ray.data.from_torch supports an IterableDataset, but the following example fails.

https://docs.ray.io/en/latest/data/api/doc/ray.data.from_torch.html

    import torch
    from torch.utils.data import IterableDataset
    import random
    import ray

    class FakeDataIterableDataset(IterableDataset):
        def __init__(self, num_samples):
            self.num_samples = num_samples

        def __iter__(self):
            for _ in range(self.num_samples):
                features = torch.tensor([random.random() for _ in range(3)])
                label = torch.tensor([random.randint(0, 1)])
                yield features, label

    num_samples = 1000
    dataset = FakeDataIterableDataset(num_samples)
    ds = ray.data.from_torch(dataset)
    ds.take(1)

TypeError: object of type 'FakeDataIterableDataset' has no len()
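The error suggests that from_torch calls len() on the dataset, which this IterableDataset does not define. For a finite dataset that fits in memory, one possible workaround is to materialize the samples yourself and build the Dataset with `ray.data.from_items`. A minimal sketch follows; the helper `samples_to_rows` and the column names `features` and `label` are hypothetical choices, not part of Ray's API.

```python
from typing import Any, Dict, Iterable, List, Tuple


def samples_to_rows(samples: Iterable[Tuple[Any, Any]]) -> List[Dict[str, Any]]:
    """Materialize (features, label) pairs from a finite iterable into row dicts.

    Everything is held in memory at once, so this only suits small, finite datasets.
    """
    rows = []
    for features, label in samples:
        # torch.Tensor (and NumPy arrays) expose .tolist(); plain Python lists pass through.
        if hasattr(features, "tolist"):
            features = features.tolist()
        if hasattr(label, "tolist"):
            label = label.tolist()
        rows.append({"features": features, "label": label})
    return rows
```

With the repro above, that would look like `ds = ray.data.from_items(samples_to_rows(dataset))`. The trade-off is that the whole stream is pulled into one process before Ray sees it, so this does not help with unbounded or very large iterable datasets.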

cc: @bveeramani any insight here?

Hey @localh, I don’t think there’s an easy way to do that right now. With the way map_groups is implemented, the data doesn’t remain grouped after map_groups runs, so batch boundaries don’t line up with group boundaries.
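One possible workaround, not discussed in the thread and offered here only as an untested sketch: have the map_groups UDF collapse each group into a single row whose cells hold lists, so that every output row is one complete group. The helper name `pack_group` is hypothetical, and its default key `"variety"` simply matches the iris example.

```python
import pandas as pd


def pack_group(group: pd.DataFrame, key: str = "variety") -> pd.DataFrame:
    """Collapse one group into a single row.

    The key column keeps its scalar value; every other column becomes a list
    holding that column's values for the whole group.
    """
    packed = {key: [group[key].iloc[0]]}
    for col in group.columns:
        if col != key:
            packed[col] = [group[col].tolist()]
    return pd.DataFrame(packed)
```

It would be applied as `ds.groupby("variety").map_groups(pack_group, batch_format="pandas")`, after which `ds.take_batch(32)` should return up to 32 rows, that is, up to 32 complete groups. The trade-off is that each group must fit in memory as a single row.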

Thanks @bveeramani for the quick response. I did not know that. Cheers!

@localh I’ll close this issue, since the current implementation does not provide an easy and intuitive way to do this.

cc: @bveeramani