How to pass objects to workers of Ray's DataParallelTrainer?

How severely does this issue affect your experience of using Ray?

  • None: Just asking a question out of curiosity
  • Low: It annoys or frustrates me for a moment.
  • Medium: It causes significant difficulty in completing my task, but I can work around it.
  • High: It blocks me from completing my task.

High. I’m unable to proceed with my task.

We are trying to build a training framework on top of Ray's DataParallelTrainer. For that we need a way to pass an object to train_loop_per_worker. From the documentation, it appears we can only pass a train_loop_config (see ray/data_parallel_trainer.py at commit 44ef068bd1edbe3e54818c22ccd12a7dd50937f1 in ray-project/ray on GitHub) to each train function. So, is there a way to pass an object to train_loop_per_worker?

train_loop_config can be a dictionary.
You could save your object under a key in that dict and retrieve it inside your train_loop_per_worker.
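For example (a minimal sketch; my_object and the dict key are just illustrative names, not part of the API):

my_object = {"vocab_size": 10_000}  # any picklable object; values are made up

def train_loop_per_worker(config):
    # The trainer hands train_loop_config to each worker as `config`.
    my_object = config["my_object"]

# then pass train_loop_config={"my_object": my_object} to the trainer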

@pratkpranav Following up on @gjoliver's advice, you could do something like this with the DataParallelTrainer:

import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker(config):
    # train_loop_config arrives here as `config` on every worker.
    my_object = config.get("key")

    # Each of the 3 workers gets one shard of the 3-item dataset.
    dataset_shard_for_this_worker = session.get_dataset_shard("train")
    assert dataset_shard_for_this_worker.count() == 1

my_object = "anything picklable"  # placeholder for the object you want to pass
train_dataset = ray.data.from_items([1, 2, 3])
assert train_dataset.count() == 3

trainer = DataParallelTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"key": my_object},
    scaling_config=ScalingConfig(num_workers=3),
    datasets={"train": train_dataset},
)
result = trainer.fit()
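One caveat: train_loop_config is serialized and shipped to every worker, so anything you put in it must be picklable, and a large object gets copied to each worker. If that becomes a problem, one workaround (a sketch, assuming the object fits in the Ray object store; big_object and the key name are illustrative) is to ray.put the object once and pass only the ObjectRef, then ray.get it inside the loop:

import ray
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

big_object = list(range(1_000_000))  # stand-in for a large picklable object
big_object_ref = ray.put(big_object)  # stored once in the Ray object store

def train_loop_per_worker(config):
    # Each worker fetches the shared object by reference instead of
    # receiving its own pickled copy through the config dict.
    big_object = ray.get(config["big_object_ref"])

trainer = DataParallelTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"big_object_ref": big_object_ref},
    scaling_config=ScalingConfig(num_workers=3),
)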

Thanks for the replies, @Jules_Damji @gjoliver. I am looking into this.