Custom data sharding in DataParallelTrainer

Hi, I am facing the following problem.

I am creating a dict of Datasets and I want a model to run on each specific dataset (so not just randomly taking n batches of my original data).
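For context, data_splitted is built roughly like this (the column name "group" and the group values are just placeholders for my real data):

import ray

ds = ray.data.from_items(
    [
        {"group": "a", "x": 1.0},
        {"group": "b", "x": 2.0},
        {"group": "c", "x": 3.0},
    ]
)

# one Dataset per group, keyed by the group name
data_splitted = {
    g: ds.filter(lambda row, g=g: row["group"] == g)
    for g in ["a", "b", "c"]
}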

As far as I understand, once I have this dict of Datasets I can specify the split I want in DataParallelTrainer,

but then, in my train loop function, I am not sure how to access the right dataset for each worker:

from ray import train
from ray.train import DataConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer


def train_loop_for_worker():
    # I would like to get the data for each key of data_splitted here,
    # i.e. resolve the argument of get_dataset_shard dynamically per worker
    dataset_shard_for_this_worker = train.get_dataset_shard(?)

    # do stuff with the data here


trainer = DataParallelTrainer(
    train_loop_for_worker,
    scaling_config=ScalingConfig(num_workers=3),
    # data_splitted is the dict of Datasets I want to run train_loop_for_worker on
    dataset_config=DataConfig(datasets_to_split=list(data_splitted.keys())),
    datasets=data_splitted,
)

result = trainer.fit()

I would basically like the workers to run the model on each group I defined in the dataset config, i.e. have a way to dynamically change the argument of train.get_dataset_shard(), but I am not sure how to achieve that.

If there is a better way to achieve this, I am open to suggestions.

@guedojulie you can fetch a specific dataset by key, e.g., train.get_dataset_shard("my_dataset_name"). These keys are the same as those in the data_splitted dict you passed to the Trainer earlier.
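As a minimal sketch (assuming the keys of data_splitted are "a", "b" and "c", and that each worker should read a different one based on its rank):

from ray import train


def train_loop_for_worker():
    # hypothetical keys; use the actual keys of your data_splitted dict
    keys = ["a", "b", "c"]

    # pick this worker's dataset by its world rank
    # (assumes num_workers == len(keys))
    key = keys[train.get_context().get_world_rank()]
    shard = train.get_dataset_shard(key)

    # iterate over the chosen dataset
    for batch in shard.iter_batches(batch_size=32):
        ...  # do stuff with the data here

One thing to double-check: datasets listed in datasets_to_split are sharded across all workers, so with your config each worker would only see a fraction of the dataset it picks. If each worker is meant to read its whole group, the corresponding key probably shouldn't be included in datasets_to_split.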

For more info, check out ray.train.get_dataset_shard — Ray 2.10.0