How severely does this issue affect your experience of using Ray?
- None: Just asking a question out of curiosity
- Low: It annoys or frustrates me for a moment.
- Medium: It causes significant difficulty in completing my task, but I can work around it.
- High: It blocks me from completing my task.
High. I’m unable to proceed with my task.
We are trying to build a training framework on top of Ray's DataParallelTrainer. For that, we need a way to pass an object to training_loop_per_worker. From the documentation (ray/python/ray/train/data_parallel_trainer.py at commit 44ef068bd1edbe3e54818c22ccd12a7dd50937f1 in ray-project/ray), it appears we can only pass a train_loop_config to each train function. So, is there a way to pass an object to training_loop_per_worker?
train_loop_config can be a dictionary.
You could save your object under a special key and retrieve it from within your training_loop_per_worker?
@pratkpranav Following up on @gjoliver's advice, you could do something like this with the DataParallelTrainer:
import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

def train_loop_per_worker(config):
    # Each worker gets its own shard of the "train" dataset.
    dataset_shard_for_this_worker = session.get_dataset_shard("train")
    # Retrieve the object passed in via train_loop_config.
    my_object = config.get("key")
    # 3 items split across 3 workers -> 1 item per shard.
    assert dataset_shard_for_this_worker.count() == 1

train_dataset = ray.data.from_items([1, 2, 3])
assert train_dataset.count() == 3

trainer = DataParallelTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"key": <value>},  # replace <value> with the object each worker should see
    scaling_config=ScalingConfig(num_workers=3),
    datasets={"train": train_dataset},
)
result = trainer.fit()
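
One caveat worth noting: train_loop_config is serialized and shipped to every worker, so if the object is large you may prefer to store it once in the Ray object store with ray.put and pass only the ObjectRef through the config. Here is a minimal sketch of that pattern (not from the thread above); big_object and the "obj_ref" key are just illustrative names:

import ray
from ray.air.config import ScalingConfig
from ray.train.data_parallel_trainer import DataParallelTrainer

# Hypothetical large object, stored once in the object store.
big_object = {"weights": list(range(1_000_000))}
obj_ref = ray.put(big_object)

def train_loop_per_worker(config):
    # Resolve the ObjectRef on the worker instead of embedding a full
    # serialized copy of the object inside the config dict itself.
    big_object = ray.get(config["obj_ref"])
    assert len(big_object["weights"]) == 1_000_000

trainer = DataParallelTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config={"obj_ref": obj_ref},
    scaling_config=ScalingConfig(num_workers=3),
)
result = trainer.fit()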
Thanks for the replies, @Jules_Damji @gjoliver. I am looking into this.