Tune with Function API and torch.multiprocessing.spawn

I’m trying to use the basic DistributedDataParallel (DDP) setup from PyTorch as the underlying experiment within Ray Tune. I can follow the setup as outlined in the documentation here, and everything works with non-DDP PyTorch training, but introducing sub-processes with torch.multiprocessing.spawn causes results to never be reported back to the main process of tuner.fit().

To make what’s going on a bit more concrete, here is some pseudo-ish code. This non-DDP version works fine:

from ray import train, tune

def trainable(config: dict):
    # Not shown: general setup
    _trainable(config)

def _trainable(config: dict):
    # Not shown: while True loop with train step, inference, etc
    # eventually in the loop:
    train.report({'result': ... })

tuner = tune.Tuner(trainable, param_space={...})
results = tuner.fit()

But this DDP version does not work:

import torch
from ray import train, tune

def trainable(config: dict):
    # Not shown: general setup

    # eventually calls:
    torch.multiprocessing.spawn(
        _distributed_worker,
        nprocs=...,
        args=(
            _trainable,
            config
        ),
        daemon=False,
    )

def _distributed_worker(local_rank, _trainable, config):
    # Not shown: PyTorch setup, things like torch.distributed.init_process_group()

    # eventually calls: 
    _trainable(config)  # same _trainable definition as in the non-DDP version above

tuner = tune.Tuner(trainable, param_space={...})
results = tuner.fit()

The problem seems to be with the train.report() call. From some debugging runs I can see that train.report() is called, but it doesn’t communicate back to the main process of tuner.fit(). It seems that torch.multiprocessing.spawn loses the needed context. In the non-DDP version I can see the expected values from the Ray context, such as train.get_context().get_trial_name(), but in the DDP version it returns None.
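
One workaround I’ve been experimenting with is to keep train.report() in the parent trainable process (which still has the Ray context) and have the spawned workers hand metrics back over a multiprocessing queue. This is only a rough sketch, not my actual code; the worker count and queue plumbing are placeholders:

import torch
import torch.multiprocessing as mp
from ray import train, tune

NUM_WORKERS = 2  # placeholder worker count for this sketch

def trainable(config: dict):
    # The parent process still has the Ray Train session, so train.report()
    # works here even though it does not inside the spawned workers.
    result_queue = mp.get_context("spawn").SimpleQueue()

    # join=False so the parent can poll the queue while the workers run.
    spawn_ctx = torch.multiprocessing.spawn(
        _distributed_worker,
        nprocs=NUM_WORKERS,
        args=(config, result_queue),
        join=False,
        daemon=False,
    )

    # Forward worker metrics to Tune from the parent process.
    while not spawn_ctx.join(timeout=1):
        while not result_queue.empty():
            train.report(result_queue.get())

    # Drain anything reported just before the workers exited.
    while not result_queue.empty():
        train.report(result_queue.get())

def _distributed_worker(local_rank, config, result_queue):
    # Not shown: torch.distributed.init_process_group(), DDP model setup, etc.
    # In the training loop, rank 0 puts metrics on the queue instead of
    # calling train.report() directly:
    if local_rank == 0:
        result_queue.put({'result': ...})

This works in the sense that results reach tuner.fit(), but it feels like I’m re-implementing plumbing that Ray already has.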

Any ideas on how to fix this more cleanly? Maybe something equivalent to a train.set_context(context) API, with pickle-able context information that could be passed through torch.multiprocessing.spawn?
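
If there’s no way to pass the context through, I may fall back to letting Ray manage the distributed workers via TorchTrainer instead of calling torch.multiprocessing.spawn myself, since then every worker should have a valid session for train.report(). A rough sketch of what I think that would look like (the worker count and config keys are placeholders, not my actual setup):

from ray import train, tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config: dict):
    # Ray starts this function on each distributed worker and sets up the
    # process group, so no torch.multiprocessing.spawn is needed.
    # Not shown: model = train.torch.prepare_model(model), training loop, etc.
    train.report({'result': ...})

trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=True),
)

tuner = tune.Tuner(trainer, param_space={"train_loop_config": {...}})
results = tuner.fit()

I’d prefer to keep my existing spawn-based setup if possible, though, since the rest of the codebase is built around it.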