1. Severity of the issue: (select one)
None: I’m just curious or want clarification.
Low: Annoying but doesn’t hinder my work.
Medium: Significantly affects my productivity but can find a workaround.
High: Completely blocks me.
2. Environment:
- Ray version: 2.48.0
- Python version: 3.11.13
- OS:
- Cloud/Infrastructure: AWS
- Other libs/tools (if relevant):
3. What happened vs. what you expected:
- Expected: Ray Data to release the data (object store memory) as the dataset is consumed.
- Actual: It isn't cleaning up; the objects keep getting spilled to disk.
We create the Ray Dataset by passing S3 paths:
def _create_ray_dataset(self, parquet_paths: List[str]):
    ray_ds = ray.data.read_parquet(parquet_paths, parallelism=2)
    return ray_ds
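For reference, a minimal driver sketch of how this gets called (the bucket and file names below are placeholders, not our real paths):

import ray

ray.init()

parquet_paths = [
    "s3://example-bucket/data/part-0.parquet",
    "s3://example-bucket/data/part-1.parquet",
]

# Same call as in _create_ray_dataset above; schema() should only need file metadata.
ray_ds = ray.data.read_parquet(parquet_paths, parallelism=2)
print(ray_ds.schema())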
The resulting dataset is passed to the debugging snippet below, which reproduces the issue. The idea is to do the actual per-row processing once the spill issue is resolved.
from typing import Callable, Iterable, Optional

import torch
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer


class StreamingParquetDataset(IterableDataset):
    def __init__(
        self,
        dataset: Iterable,
        tokenizer_name: Optional[str] = "",
        max_length: int = 512,
        pack_sequences: bool = True,
        text_preprocessing_fn: Optional[Callable] = None,
        is_val: bool = False,
    ):
        self.dataset = dataset
        self.max_length = max_length
        self.pack_sequences = pack_sequences
        self.text_preprocessing_fn = text_preprocessing_fn
        self.is_val = is_val

        # Initialize tokenizer if needed
        self.tokenizer = None
        if tokenizer_name:
            self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

    def __iter__(self):
        # for row in self.dataset.iter_rows():
        for i, row in enumerate(self.dataset.iter_batches(batch_size=1)):
            # Drop the batch immediately and yield dummy tensors instead, to rule
            # out our own processing as the source of the spilling.
            del row
            tensor_ones_input_ids = torch.ones(self.max_length, dtype=torch.long)
            tensor_ones_attention_mask = torch.ones(self.max_length, dtype=torch.long)
            tensor_ones_labels = torch.ones(self.max_length, dtype=torch.long)
            result = {
                "input_ids": tensor_ones_input_ids,
                "attention_mask": tensor_ones_attention_mask,
                "labels": tensor_ones_labels,
            }
            yield result
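For completeness, this is roughly how we drive the snippet (the DataLoader wrapping below is illustrative rather than our exact training loop; batch size and logging interval are placeholders):

from torch.utils.data import DataLoader

# ray_ds comes from _create_ray_dataset(...) above; tokenizer_name is left empty,
# so no tokenizer is loaded for this repro.
stream_ds = StreamingParquetDataset(ray_ds)
loader = DataLoader(stream_ds, batch_size=8, num_workers=0)

for step, batch in enumerate(loader):
    # No real work is done here, yet object store usage grows and Ray starts spilling.
    if step % 1000 == 0:
        print(step, batch["input_ids"].shape)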
We manually added the del row call and the spilling still happens.
Any pointers on how we can resolve this?
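For example, is capping the streaming executor's object store budget the intended way to bound this? A minimal sketch of what we mean, assuming DataContext.get_current().execution_options.resource_limits is the right knob (the 2 GiB value is a placeholder):

import ray

ctx = ray.data.DataContext.get_current()
# Placeholder cap on the object store memory the streaming executor may use (2 GiB).
ctx.execution_options.resource_limits.object_store_memory = 2 * 1024 ** 3

# The dataset is then created and iterated as shown above.
ray_ds = ray.data.read_parquet(parquet_paths, parallelism=2)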