Hi
I noticed that the `_collate_fn` argument was removed from `ray.data.DataIterator.iter_batches` in the 2.47.0 release.
My understanding was that `ray.train.get_dataset_shard` should be used to iterate through batches of data when there is more than one training worker.
I also noticed that the documentation still refers to using `collate_fn` in `iter_batches`.
I’m sure there must be a reason for the change, but it breaks my workflow.
I’m not sure whether there is any other way to use a `collate_fn` without changing my dataset interface.
Can someone please help?
2.48.0 API reference
2.47.0 API reference
I’m using Ray 2.48.0 and testing it locally.
FYI, I implemented a custom dataset and a training script that use the following pattern:
```python
import ray
import ray.train

train_dataset = ray.data.read_parquet(dataset_info_file)
train_dataset = train_dataset.map(ReadDataset, concurrency=4, num_cpus=4)

###### training function ######
def train_loop_per_worker(config):
    ...
    ...
    def collate_fn(batch):
        ...
        ...
        return train_data, train_target, train_metadata

    train_data_shard = ray.train.get_dataset_shard("train")
    # _collate_fn is the argument that no longer exists as of 2.47.0.
    # prefetch_batches takes an int (the original True behaved like 1).
    train_dataloader = train_data_shard.iter_batches(
        batch_size=batch_size,
        _collate_fn=collate_fn,
        prefetch_batches=1,
        drop_last=True,
    )
    for epoch in range(start_epoch, num_epochs):
        ...
        ...
        for data, targets, metadata in train_dataloader:
            ...
            ...
```
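A stopgap I’m considering, sketched under the assumption that my `collate_fn` only needs the batch dict that `iter_batches` yields: call `iter_batches` without `_collate_fn` and apply `collate_fn` myself. `CollatedBatches` is a hypothetical helper I wrote for this, not a Ray API, and I haven’t checked whether it loses any of the prefetching behaviour the private argument had. (If the pipeline were PyTorch-only, `iter_torch_batches` also takes a `collate_fn` argument, which may be the intended route, but I haven’t verified how it handles my three-part output.)

```python
from typing import Any, Callable, Dict, Iterable, Iterator, TypeVar

T = TypeVar("T")


class CollatedBatches:
    """Hypothetical helper (not part of Ray): a re-iterable wrapper that
    applies collate_fn to every batch produced by make_batches()."""

    def __init__(
        self,
        make_batches: Callable[[], Iterable[Dict[str, Any]]],
        collate_fn: Callable[[Dict[str, Any]], T],
    ) -> None:
        self._make_batches = make_batches
        self._collate_fn = collate_fn

    def __iter__(self) -> Iterator[T]:
        # make_batches() is called again on every pass, so the same object
        # can be looped over once per epoch.
        for batch in self._make_batches():
            yield self._collate_fn(batch)


# Assumed usage inside train_loop_per_worker (names from my script above):
# train_dataloader = CollatedBatches(
#     lambda: train_data_shard.iter_batches(
#         batch_size=batch_size, prefetch_batches=1, drop_last=True
#     ),
#     collate_fn,
# )
```

The factory (`lambda: ...`) rather than a single iterator is deliberate: it keeps the `for epoch in ...` loop unchanged, since each epoch gets a fresh pass over the shard.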
Thank you.