TorchTrainer hangs when only 1 worker raises error

TorchTrainer hangs until the torch DDP timeout (1800 seconds) before raising an error if one of the workers fails. Code to reproduce the issue:

import torch
import ray
from ray.train.torch import TorchTrainer
from ray.air.config import ScalingConfig
from ray.air import session

def train_fn(args):
    if session.get_local_rank() > 0:
        # Non-zero ranks block in the barrier waiting for rank 0.
        print("Waiting for main process to do something")
        torch.distributed.barrier()
        print("we are out")
    else:
        # Rank 0 fails before ever reaching the barrier.
        raise ValueError
        torch.distributed.barrier()  # never reached

ray.init()

use_gpu = True

trainer = TorchTrainer(
    train_fn,
    scaling_config=ScalingConfig(
        num_workers=6, use_gpu=use_gpu)
)

results = trainer.fit()

The code is stuck for 1800 seconds (the default process group timeout) before any error is raised.

Hi @saurabh3949, this is happening because each worker executes the training function in its own thread, and the primary TorchTrainer process aggregates the results/errors from all of the workers.
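
For illustration, here is a minimal sketch with plain Ray tasks (not Ray Train's actual internals): the driver holds one future per worker, and whether an error surfaces right away depends on how it waits on those futures.

# A sketch only: plain Ray tasks standing in for training workers.
import time
import ray

ray.init()

@ray.remote
def worker(rank):
    if rank == 0:
        raise ValueError(f"worker {rank} failed")
    time.sleep(3600)  # stand-in for workers stuck in a collective barrier
    return rank

futures = [worker.remote(i) for i in range(6)]

# Blocking on all futures at once would wait on the stuck workers; polling
# with ray.wait surfaces the failed task as soon as it finishes.
pending = futures
while pending:
    done, pending = ray.wait(pending, num_returns=1)
    ray.get(done)  # re-raises the worker's ValueError (wrapped in a RayTaskError)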

Could you share some more details about what your expected behavior is?

Thanks for your reply, Matthew. I was hoping the job would fail as soon as one of the workers fails. That is the behavior I get if I run the job with python -m torch.distributed.launch.
Is there any way the primary TorchTrainer can capture the error and surface it?

Oh hm, my understanding was that torch.distributed.launch would also hang, but looking at the torchrun docs I do see:

On failures or membership changes ALL surviving workers are killed immediately.

@saurabh3949 could you create a GitHub issue to request/track this change in functionality? Including the comparison script run with torch.distributed.launch would be super helpful too!

cc @amogkam @kai

Hey @saurabh3949, this should be fixed in nightly!

The behavior now matches torch.distributed.launch: if any worker errors, the entire training run terminates and the error is raised immediately.
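
For reference, a quick way to sanity-check it (a sketch based on the original repro, not an official test; it is scaled down to 2 CPU workers, and the except clause is broad because the exact wrapper exception type may differ):

import time

import ray
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer


def train_fn(config):
    if session.get_world_rank() == 0:
        raise ValueError("boom")
    time.sleep(3600)  # the surviving worker would otherwise wait indefinitely


ray.init()
trainer = TorchTrainer(
    train_fn,
    scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
)

start = time.time()
try:
    trainer.fit()
except Exception as err:
    # With the fix, this prints after seconds rather than after the 1800 second timeout.
    print(f"Training failed after {time.time() - start:.1f}s with: {err!r}")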

Thanks a lot, @amogkam!

@amogkam Ray still hangs if I initialize a DDP model that uses NCCL. The issue can be replicated using the following code:

import sys
import argparse
import torch
import logging
from torch.nn.parallel import DistributedDataParallel as DDP

logger = logging.getLogger()


class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(100, 200)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(200, 10)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x


def save_model(args):
    if args.local_rank == 0:
        raise ValueError
    else:
        print("Waiting for rank 0 to raise error")

    torch.distributed.barrier()


def main(args):
    print("args from worker", args)
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int)
    args, _ = parser.parse_known_args(args)

    device_id = args.local_rank
    device = torch.device("cuda", device_id)
    n_gpu = torch.cuda.device_count()
    torch.cuda.set_device(device)
    # The process group must be initialized before DDP or barrier can be used.
    torch.distributed.init_process_group(backend="nccl")
    tinymodel = TinyModel()
    tinymodel.to(device)
    ddp_model = DDP(tinymodel, device_ids=[args.local_rank], output_device=args.local_rank)

    save_model(args)


if __name__ == '__main__':
    args = sys.argv[1:]
    main(args)

Hey @saurabh3949, this code snippet is not using Ray.

Oh sorry, here you go. @amogkam, please let me know if you are not able to replicate the issue with this snippet.

import logging

import ray
import torch
from ray.air import session
from ray.air.config import ScalingConfig
from ray.train.torch import TorchTrainer
from torch.nn.parallel import DistributedDataParallel as DDP

logger = logging.getLogger()


class TinyModel(torch.nn.Module):

    def __init__(self):
        super(TinyModel, self).__init__()

        self.linear1 = torch.nn.Linear(100, 200)
        self.activation = torch.nn.ReLU()
        self.linear2 = torch.nn.Linear(200, 10)
        self.softmax = torch.nn.Softmax()

    def forward(self, x):
        x = self.linear1(x)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.softmax(x)
        return x


def save_model():
    if session.get_local_rank() == 0:
        raise ValueError
    else:
        print("Waiting for rank 0 to raise error")

    torch.distributed.barrier()


def train_fn(config):
    local_rank = session.get_local_rank()
    device_id = local_rank
    device = torch.device("cuda", device_id)
    n_gpu = torch.cuda.device_count()
    torch.cuda.set_device(device)
    tinymodel = TinyModel()
    tinymodel.to(device)
    
    # If you comment out the line below, then the job will exit immediately. 
    ddp_model = DDP(tinymodel, device_ids=[local_rank], output_device=local_rank)
    save_model()


if __name__ == '__main__':
    ray.init()
    use_gpu = True

    trainer = TorchTrainer(
        train_fn,
        scaling_config=ScalingConfig(
            num_workers=6, use_gpu=use_gpu)
    )

    results = trainer.fit()

I feel it has something to do with NCCL. If you comment out the line that wraps the model in DDP, the Ray job throws the error immediately; otherwise it gets stuck.

Supplying a lower timeout_s to TorchConfig helps (the hang then only lasts as long as the shorter process group timeout), but I'd still expect Ray to throw the error immediately.

    from ray.train.torch import TorchConfig  # needed for torch_config below

    trainer = TorchTrainer(
        train_fn,
        scaling_config=ScalingConfig(
            num_workers=6, use_gpu=use_gpu),
        torch_config=TorchConfig(timeout_s=10)
    )

Hey @saurabh3949, just to clarify, you are using the nightly version of Ray, right?

Yes. Could you please point me to the commit where you fixed the issue before?

Here is the PR: [Train] Immediately fail if application errors on any worker by amogkam · Pull Request #28314 · ray-project/ray · GitHub.

Let me also try your code snippet.

Yes, I can confirm that the nightly build I am using has your commits.

Thanks @saurabh3949, this should now be fixed once [Train] Immediately fail on any worker failure by amogkam · Pull Request #29927 · ray-project/ray · GitHub is merged.