Distributed training with uneven inputs

I’m trying to replace torch.DDP with ray.train and hit the problem of different DDP instances having different numbers of input batches.

DDP’s ‘Join’ context manager can fix this problem, as shown below.
How can I handle this problem using ray.train?

...
    model = DDP(torch.nn.Linear(1, 1).to(rank), device_ids=[rank])
    # join() keeps the collective ops in sync even when workers run
    # different numbers of iterations
    with model.join():
        for input in inputs:
            loss = model(input).sum()
            ...

https://pytorch.org/tutorials/advanced/generic_join.html#using-join-with-distributeddataparallel

Hey @ben1, Ray Train supports Torch DDP and you should be able to run this as well.
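
For reference, here is a minimal, untested sketch of how a DDP loop like the one above could run inside a Ray Train worker, assuming the Ray Train 2.x API (TorchTrainer, ScalingConfig, prepare_model) and placeholder model/data; prepare_model is expected to return the DDP-wrapped module when there is more than one worker, so join() should be available on it.

import torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer, prepare_model

def train_loop_per_worker(config):
    # prepare_model is expected to wrap the module in DistributedDataParallel
    # (and move it to the right device) when there are multiple workers
    model = prepare_model(torch.nn.Linear(1, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # placeholder batches; in practice each worker gets its own (possibly uneven) shard
    inputs = [torch.randn(8, 1) for _ in range(4)]

    # join() is available here because model is the DDP-wrapped module
    with model.join():
        for batch in inputs:
            loss = model(batch).sum()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

trainer = TorchTrainer(train_loop_per_worker, scaling_config=ScalingConfig(num_workers=2))
trainer.fit()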

Yes, I found out that the actual blocking part is train.report.
Due to the uneven inputs, some workers finish the dataloader loop early, but train.report must be called the same number of times by all workers.

How can I handle this using ray.train?

with model.join():
    iter_idx = 0
    for input in inputs:
        loss = model(input).sum()
        iter_idx += 1
        if iter_idx % report_num == 0:
            train.report(metrics)
        ...

@ben1 It may be easier to just report one set of aggregated metrics at the end of the epoch, rather than every N steps within the loop.

for epoch in range(epochs):
    for input in inputs:
        ...
    # report once per epoch (or every `report_num` epochs) so every worker
    # calls train.report the same number of times
    if epoch % report_num == 0:
        train.report(aggregated_metrics)

train.report does act as a barrier to keep the metric/checkpoint reporting iteration consistent across workers.
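
Putting the two together, a rough, untested sketch of the per-epoch reporting pattern inside a Ray Train worker could look like the following; prepare_model, the placeholder inputs, and the loss aggregation here are illustrative assumptions, not code from this thread.

import torch
from ray import train
from ray.train.torch import prepare_model

def train_loop_per_worker(config):
    model = prepare_model(torch.nn.Linear(1, 1))
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # placeholder data; worker shards may have different lengths
    inputs = [torch.randn(8, 1) for _ in range(4)]

    for epoch in range(config.get("epochs", 2)):
        running_loss, num_batches = 0.0, 0
        # join() lets workers that run out of batches shadow the collective
        # ops of workers that are still iterating, so nothing hangs
        with model.join():
            for batch in inputs:
                loss = model(batch).sum()
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                num_batches += 1
        # every worker reaches this point exactly once per epoch, so the
        # train.report barrier lines up despite uneven inputs
        train.report({"epoch": epoch, "loss": running_loss / max(num_batches, 1)})

Because train.report sits outside the join() block and runs once per epoch on every worker, the barrier stays aligned even when workers see different numbers of batches.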
