TorchTrainer hangs until the torch DDP timeout (1800 sec) before throwing an error if one of the workers fails. Code to reproduce the error:
import torch
import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air import session


def train_fn(args):
    if session.get_local_rank() > 0:
        print("Waiting for main process to do something")
        torch.distributed.barrier()
        print("we are out")
    else:
        # Rank 0 raises before reaching the barrier, so the other ranks
        # block in barrier() until the process-group timeout expires.
        raise ValueError
        torch.distributed.barrier()


ray.init()
use_gpu = True
trainer = TorchTrainer(
    train_fn,
    scaling_config=ScalingConfig(num_workers=6, use_gpu=use_gpu),
)
results = trainer.fit()
The code will be stuck for 1800 seconds before throwing any error.
Hi @saurabh3949, this is happening because each worker executes the training function in its own thread, and the primary TorchTrainer aggregates the results/errors.
Could you share some more details about what your expected behavior is?
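In the meantime, if the 1800-second wait itself is the main pain point, one mitigation is to shorten the torch process-group timeout so the surviving workers fall out of the barrier sooner. A minimal sketch, assuming the default comes from TorchConfig's timeout_s and that your backend actually enforces the collective timeout:

from ray.air.config import ScalingConfig
from ray.train.torch import TorchConfig, TorchTrainer

trainer = TorchTrainer(
    train_fn,
    # timeout_s is forwarded to the torch.distributed process group;
    # 60 is only an illustrative value, not a recommendation.
    torch_config=TorchConfig(timeout_s=60),
    scaling_config=ScalingConfig(num_workers=6, use_gpu=use_gpu),
)
results = trainer.fit()

Note this only bounds how long the hang lasts; it doesn't make the first worker failure propagate immediately.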
Thanks for your reply, Matthew. I was hoping that the job fails as soon as one of the workers fails. This is the behavior I get if I run the job with python -m torch.distributed.launch.
Is there any way the primary TorchTrainer captures the error and surfaces it?
Hm, my understanding was that torch.distributed.launch would also hang, but looking at the torchrun docs I do see:
On failures or membership changes ALL surviving workers are killed immediately.
@saurabh3949 could you create a GitHub issue to request/track this change in functionality? Including the comparison script run with torch.distributed would be super helpful too!
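For reference, a minimal standalone version along these lines would be enough (a sketch of the same barrier-based failure; repro_torchrun.py is a hypothetical filename):

# Run with: torchrun --nproc_per_node=6 repro_torchrun.py
import torch.distributed as dist

def main():
    # gloo keeps the sketch CPU-only; swap in "nccl" to match the GPU setup
    dist.init_process_group(backend="gloo")
    if dist.get_rank() > 0:
        print("Waiting for main process to do something")
        dist.barrier()
        print("we are out")
    else:
        # Rank 0 dies before the barrier; torchrun is expected to tear down
        # the surviving ranks instead of letting them wait out the timeout.
        raise ValueError

if __name__ == "__main__":
    main()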
I think it has something to do with NCCL. If you comment out the line that wraps the model in DDP, the Ray job throws the error immediately; otherwise it gets stuck.
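To be concrete, by "the line that wraps the model in DDP" I mean the usual prepare_model step in the real training function; an illustrative sketch, not my exact script:

import torch
import torch.nn as nn
from ray.air import session
from ray.train.torch import prepare_model

def train_fn(config):
    model = nn.Linear(10, 1)
    # This is the line I comment out: it moves the model to the worker's
    # device and wraps it in DistributedDataParallel (setting up NCCL comms).
    model = prepare_model(model)
    if session.get_local_rank() > 0:
        torch.distributed.barrier()
    else:
        raise ValueError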