Ray parquet data streaming causing massive spill and storage issues

1. Severity of the issue: (select one)

None: I’m just curious or want clarification.
Low: Annoying but doesn’t hinder my work.
Medium: Significantly affects my productivity but can find a workaround.
High: Completely blocks me.

2. Environment:

  • Ray version: 2.48.0
  • Python version: 3.11.13
  • OS:
  • Cloud/Infrastructure: AWS
  • Other libs/tools (if relevant):

3. What happened vs. what you expected:

  • Expected: Ray Data to clean up streamed blocks and spilled objects as iteration proceeds.
  • Actual: It isn't cleaning up; spilled data keeps accumulating and eating storage.

We create the Ray dataset by passing S3 paths:
def _create_ray_dataset(self, parquet_paths: List[str]):
    ray_ds = ray.data.read_parquet(parquet_paths, parallelism=2)
    return ray_ds
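
For context, a minimal sketch of how this read gets invoked; the bucket and file names below are placeholders, not our real paths:

import ray

ray.init()

# Placeholder paths; the real list is built by our data pipeline.
parquet_paths = [
    "s3://example-bucket/train/part-0000.parquet",
    "s3://example-bucket/train/part-0001.parquet",
]

# Same call the helper above makes.
ray_ds = ray.data.read_parquet(parquet_paths, parallelism=2)
print(ray_ds.schema())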

The resulting dataset is passed to the debugging snippet below, which reproduces our issue. The idea is to do the real per-row processing once the spill issue is resolved; for now the loop just discards each row and yields dummy tensors.

from typing import Callable, Iterable, Optional

import torch
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer


class StreamingParquetDataset(IterableDataset):
    def __init__(
        self,
        dataset: Iterable,
        tokenizer_name: Optional[str] = "",
        max_length: int = 512,
        pack_sequences: bool = True,
        text_preprocessing_fn: Optional[Callable] = None,
        is_val: bool = False,
    ):
        self.dataset = dataset
        self.max_length = max_length
        self.pack_sequences = pack_sequences
        self.text_preprocessing_fn = text_preprocessing_fn
        self.is_val = is_val

        # Initialize tokenizer if needed
        self.tokenizer = None
        if tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def __iter__(self):
        # for row in self.dataset.iter_rows():
        for i, row in enumerate(self.dataset.iter_batches(batch_size=1)):
            # Drop the batch immediately; the actual data is never touched.
            del row
            # Yield constant dummy tensors instead of real tokenized output.
            tensor_ones_input_ids = torch.ones(self.max_length, dtype=torch.long)
            tensor_ones_attention_mask = torch.ones(self.max_length, dtype=torch.long)
            tensor_ones_labels = torch.ones(self.max_length, dtype=torch.long)
            result = {
                "input_ids": tensor_ones_input_ids,
                "attention_mask": tensor_ones_attention_mask,
                "labels": tensor_ones_labels,
            }
            yield result
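
For completeness, this is roughly how we consume the dataset during training; the batch size and DataLoader settings below are illustrative, not our exact production values:

from torch.utils.data import DataLoader

# ray_ds is the Ray dataset returned by _create_ray_dataset above.
stream_ds = StreamingParquetDataset(ray_ds, tokenizer_name="", max_length=512)

# num_workers=0 keeps iteration of the Ray dataset in this process instead of
# pickling it into DataLoader worker processes.
loader = DataLoader(stream_ds, batch_size=8, num_workers=0)

for step, batch in enumerate(loader):
    pass  # real training step goes here once the spilling is resolved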

We manually added the del call shown above, and the spilling still happens.

Any pointers on how we can resolve this?