I'm trying to use the basic DistributedDataParallel (DDP) setup from PyTorch as the underlying experiment within Ray Tune. I can follow the setup as outlined in the documentation here, and everything works with non-DDP PyTorch training, but introducing sub-processes with `torch.multiprocessing.spawn` causes results to never be reported back to the main process of `tuner.fit()`.

For some pseudo-ish code to make what's going on a bit more concrete, this non-DDP version works fine:
```python
from ray import train, tune

def trainable(config: dict):
    # Not shown: general setup
    _trainable(config)

def _trainable(config: dict):
    # Not shown: while True loop with train step, inference, etc.
    # Eventually, inside the loop:
    train.report({'result': ...})

tuner = tune.Tuner(trainable, param_space={...})
results = tuner.fit()
```
But this DDP version does not work:
```python
import torch
from ray import train, tune

def trainable(config: dict):
    # Not shown: general setup
    # Eventually calls:
    torch.multiprocessing.spawn(
        _distributed_worker,
        nprocs=...,
        args=(
            _trainable,
            config,
        ),
        daemon=False,
    )

def _distributed_worker(local_rank, _trainable, config):
    # Not shown: PyTorch setup, things like torch.distributed.init_process_group()
    # Eventually calls:
    _trainable(config)  # same _trainable definition as in the non-DDP version above

tuner = tune.Tuner(trainable, param_space={...})
results = tuner.fit()
```
The problem seems to be with the `train.report()` call. I can do some debugging runs and see that `train.report()` is called, but it doesn't communicate back to the main process of `tuner.fit()`. It seems that `torch.multiprocessing.spawn` loses the context needed. In the non-DDP version I can see the expected values from the Ray context, such as `train.get_context().get_trial_name()`, but in the DDP version it returns `None`.
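For example, this is roughly the check I'm running inside the spawned worker (simplified):

```python
from ray import train

def _distributed_worker(local_rank, _trainable, config):
    # In the non-DDP trainable this returns the trial name as expected;
    # inside the spawned worker it returns None.
    print(train.get_context().get_trial_name())
    _trainable(config)
```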
Any ideas on how to fix this? Maybe something equivalent to a `train.set_context(context)` API, with pickle-able context information that could be passed through `torch.multiprocessing.spawn`?
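To make the ask concrete, here is a rough sketch of what I have in mind. As far as I know `train.set_context()` does not exist, and I'm assuming the context object (or some snapshot of it) could be made pickle-able, so treat this purely as a hypothetical:

```python
import torch
from ray import train, tune

def trainable(config: dict):
    # Grab the Ray Train/Tune context in the trial's main process,
    # where it is still populated...
    context = train.get_context()
    torch.multiprocessing.spawn(
        _distributed_worker,
        nprocs=...,
        args=(_trainable, config, context),
        daemon=False,
    )

def _distributed_worker(local_rank, _trainable, config, context):
    # ...and restore it in the spawned worker, so that train.report()
    # inside _trainable can find its way back to tuner.fit().
    train.set_context(context)  # hypothetical API, does not exist today
    _trainable(config)
```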