Hi, I am facing the following problem:
I am creating a dict of Datasets,
and I want a model to run on each specific dataset (so not just on n random batches of my original data).
As far as I understand, once I have this dict of Datasets,
I can specify the split I want in DataParallelTrainer,
but in my training loop function I am not sure how each worker can access the right Dataset.
from ray import train
from ray.train import DataConfig, ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_for_worker():
    # I would like to get the data for one key of data_splitted here,
    # i.e. resolve the argument of get_dataset_shard dynamically.
    dataset_shard_for_this_worker = train.get_dataset_shard(?)
    # do stuff with the data here

# data_splitted is a dict of Datasets I want to run train_loop_for_worker on
trainer = DataParallelTrainer(
    train_loop_for_worker,
    scaling_config=ScalingConfig(num_workers=3),
    dataset_config=DataConfig(datasets_to_split=list(data_splitted.keys())),
    datasets=data_splitted,
)
result = trainer.fit()
Basically, I would like each worker to run the model on one of the groups I defined in dataset_config, i.e. to have a way to dynamically change the argument of train.get_dataset_shard(),
but I am not sure how to achieve that.
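One direction that might work is to choose the dataset name from the worker's rank. The sketch below is only an illustration under assumptions, not a confirmed solution: pick_dataset_key is a hypothetical helper, and the "dataset_keys" entry in train_loop_config is a made-up name for passing the dict keys to the workers. As far as I understand, a dataset listed in datasets_to_split is split across all workers, so it would probably have to be left out of that list for a worker to see the whole dataset.

```python
from typing import Sequence


def pick_dataset_key(keys: Sequence[str], world_rank: int) -> str:
    """Map a worker's rank to one dataset name (hypothetical helper)."""
    # Sort so that every worker derives the same ordering of keys.
    ordered = sorted(keys)
    if not ordered:
        raise ValueError("no dataset keys given")
    # Wrap around if there are more workers than datasets.
    return ordered[world_rank % len(ordered)]


# Intended use inside the training loop (not executed here, requires Ray).
# Assumes the keys were passed as train_loop_config={"dataset_keys": [...]}:
#
#   def train_loop_for_worker(config):
#       rank = ray.train.get_context().get_world_rank()
#       key = pick_dataset_key(config["dataset_keys"], rank)
#       dataset_shard_for_this_worker = ray.train.get_dataset_shard(key)
```

With num_workers=3 and three keys, each rank would map to a different key under this scheme.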
If there is a better way to achieve this, I am open to suggestions.