TorchTrainer hangs when only 1 worker raises error

TorchTrainer hangs until the torch DDP timeout (1800 seconds) before raising an error if one of the workers fails. Code to reproduce the issue:

import torch
import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air import session

def train_fn(args):
    if session.get_local_rank() > 0:
        # Non-zero ranks block here waiting for rank 0, which never reaches the barrier.
        print("Waiting for main process to do something")
        torch.distributed.barrier()
        print("we are out")
    else:
        # Rank 0 fails before the barrier, so the other workers hang until the DDP timeout.
        raise ValueError
        torch.distributed.barrier()  # unreachable

ray.init()

use_gpu = True

trainer = TorchTrainer(
    train_fn,
    scaling_config=ScalingConfig(
        num_workers=6, use_gpu=use_gpu)
)

results = trainer.fit()

The code will be stuck for 1800 seconds before throwing any error.
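The 1800 seconds appears to be the default process group timeout used by Ray's TorchConfig. As a rough stopgap sketch, assuming TorchConfig's timeout_s is passed through to torch.distributed.init_process_group, lowering it makes the hang surface sooner (timeout_s=60 below is just an illustrative value), but it does not fix the underlying behavior:

from ray.train.torch import TorchConfig

# Sketch only: shorten the process group timeout so the hang surfaces faster.
trainer = TorchTrainer(
    train_fn,
    torch_config=TorchConfig(timeout_s=60),
    scaling_config=ScalingConfig(num_workers=6, use_gpu=use_gpu),
)
results = trainer.fit()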

Hi @saurabh3949, this is happening because each worker executes the training function in its own thread, and the primary TorchTrainer aggregates the results/errors from those workers.

Could you share some more details about what your expected behavior is?

Thanks for your reply, Matthew. I was hoping that the job would fail as soon as one of the workers fails. This is the behavior I get if I run the job with python -m torch.distributed.launch (roughly the comparison sketched below).
Is there any way the primary TorchTrainer can capture the error and surface it?
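For reference, a minimal sketch of the torch.distributed comparison (the script name, backend, and launch command are placeholders/assumptions on my part):

import torch.distributed as dist

def main():
    # torchrun / torch.distributed.launch provide RANK, WORLD_SIZE, MASTER_ADDR, etc.,
    # so the default env:// rendezvous works here.
    dist.init_process_group(backend="gloo")
    if dist.get_rank() > 0:
        print("Waiting for rank 0 to do something")
        dist.barrier()
        print("we are out")
    else:
        # Rank 0 dies before the barrier; the launcher kills the surviving workers
        # instead of letting them wait out the 1800-second timeout.
        raise ValueError

if __name__ == "__main__":
    main()

# Launched with e.g.: torchrun --nproc_per_node=6 barrier_repro.py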

Oh hm, my understanding was that torch.distributed.launch would also hang, but looking at the torchrun docs I do see:

On failures or membership changes ALL surviving workers are killed immediately.

@saurabh3949 could you create a GitHub issue to request/track this change in functionality? Including the comparison script run with torch.distributed would be super helpful too!

cc @amogkam @kai

Hey @saurabh3949- this should be fixed in nightly!

The behavior now matches torch.distributed.launch: if any worker errors, the entire training run terminates and the error is immediately raised.
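For anyone finding this later, a rough sketch of how the failure now surfaces at the driver, assuming the worker's exception is re-raised from trainer.fit() (possibly wrapped, e.g. in a TrainingFailedError on recent versions):

# Uses the trainer from the reproduction above.
try:
    results = trainer.fit()
except Exception as e:
    # The ValueError from the failing worker is surfaced immediately
    # instead of after the 1800-second DDP timeout.
    print(f"Training failed: {e}")
    raise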

Thanks a lot, @amogkam!