How to get the global loss when training with PyTorch?

In each train_loop_per_worker, the loss is local. How can I report the global loss?

Hi,

I think the way to get the global loss in a distributed PyTorch training setup is to use a collective communication operation such as torch.distributed.all_reduce to aggregate the local losses from all workers. Here is a brief outline:

  1. Calculate the local loss on each worker.
  2. Use torch.distributed.all_reduce to sum up the local losses.
  3. Divide the aggregated loss by the number of workers to get the global loss.

Here is a simplified example:

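```python
import torch
import torch.distributed as dist

def train_loop_per_worker():
    # ... your training code ...

    local_loss = compute_local_loss()  # Replace with your loss computation

    # all_reduce works in place, so reduce a detached copy and leave
    # the tensor used for backward() untouched
    loss_for_logging = local_loss.detach().clone()

    # Sum the local losses from all workers, then average over the world size
    dist.all_reduce(loss_for_logging, op=dist.ReduceOp.SUM)
    global_loss = loss_for_logging / dist.get_world_size()

    # Report the global loss
    return global_loss.item()
```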
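Step 3 (dividing the summed loss by the number of workers) gives the true per-sample mean only when every worker averaged its loss over the same number of samples. The plain-Python simulation below makes that arithmetic concrete. It does not use torch.distributed: a list stands in for the workers, all_reduce_sum is a hypothetical stand-in for dist.all_reduce with ReduceOp.SUM, and the loss values and sample counts are made up.

```python
from typing import List


def all_reduce_sum(values: List[float]) -> List[float]:
    """Hypothetical stand-in for dist.all_reduce(op=SUM): every worker gets the total."""
    total = sum(values)
    return [total] * len(values)


def global_mean_loss(local_losses: List[float]) -> List[float]:
    """Steps 2 and 3: sum the per-worker losses, then divide by the world size."""
    world_size = len(local_losses)
    return [s / world_size for s in all_reduce_sum(local_losses)]


def global_mean_loss_weighted(local_losses: List[float], local_counts: List[int]) -> List[float]:
    """Reduce loss sums and sample counts separately, then divide."""
    loss_sums = all_reduce_sum([loss * n for loss, n in zip(local_losses, local_counts)])
    counts = all_reduce_sum([float(n) for n in local_counts])
    return [s / c for s, c in zip(loss_sums, counts)]


# Two workers: mean loss 0.5 over 3 samples, mean loss 1.0 over 1 sample
print(global_mean_loss([0.5, 1.0]))                   # [0.75, 0.75]
print(global_mean_loss_weighted([0.5, 1.0], [3, 1]))  # [0.625, 0.625]
```

In a real training loop the weighted variant corresponds to all-reducing a tensor that holds the local loss sum and the local sample count, rather than the local mean.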

Make sure the distributed process group has been initialized with dist.init_process_group() before calling dist.all_reduce().
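To try the aggregation code outside Ray, a minimal sketch that initializes a single-process, CPU-only process group might look like the following. The gloo backend, the 127.0.0.1:29500 rendezvous address, and the helper names init_local_process_group and average_across_workers are assumptions for illustration. With world_size=1 the "global" loss is simply the local one; with several processes, each would call the same helper with its own rank.

```python
import os

import torch
import torch.distributed as dist


def init_local_process_group(rank: int = 0, world_size: int = 1) -> None:
    """Join a CPU-only (gloo) process group using env:// rendezvous.

    Assumes every rank runs on this machine and that port 29500 is free.
    """
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)


def average_across_workers(local_loss: torch.Tensor) -> float:
    """Mean of a scalar loss over all ranks in the default process group."""
    # all_reduce works in place, so reduce a detached copy instead of the training tensor
    buf = local_loss.detach().clone()
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return (buf / dist.get_world_size()).item()
```

Call init_local_process_group() once per process, average_across_workers(loss) wherever you want to log the loss, and dist.destroy_process_group() when training ends.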

Thanks

Why doesn't Ray Train provide this, since it is a common case in distributed training? Just curious.

Hi,
Since Ray wraps the PyTorch setup for me, how should I use dist.init_process_group() in a Ray context?

Regards

You can also use existing Torch ecosystem libraries that do this for you, e.g. torchmetrics.

See an example here.
