In each train_loop_per_worker, the loss is local. How can I report the global loss?
Hi,
To get the global loss in a distributed PyTorch training setup, you can use collective communication operations like torch.distributed.all_reduce to aggregate the local losses from all workers. Here is a brief outline:
- Calculate the local loss on each worker.
- Use torch.distributed.all_reduce to sum up the local losses.
- Divide the aggregated loss by the number of workers to get the global loss.
Here is a simplified example:
import torch
import torch.distributed as dist

def train_loop_per_worker():
    # ... your training code ...
    # local_loss must be a torch.Tensor on this worker's device
    local_loss = compute_local_loss()  # Replace with your loss computation
    # Sum the local losses across all workers (all_reduce modifies the tensor in place on every worker)
    dist.all_reduce(local_loss, op=dist.ReduceOp.SUM)
    # Divide by the number of workers to get the average (global) loss
    global_loss = local_loss / dist.get_world_size()
    # Report the global loss
    return global_loss.item()
Ensure that your training script initializes the distributed process group appropriately using dist.init_process_group().
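For a plain PyTorch launch (i.e., when no framework sets up the process group for you), here is a minimal sketch of that initialization, assuming the script is started with torchrun so that RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT are already in the environment; setup_distributed and the elided model code are illustrative, not from this thread:

import os
import torch
import torch.distributed as dist

def setup_distributed():
    # With torchrun, the env:// init method reads rank/world size/master address
    # from the environment variables exported by the launcher.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend, init_method="env://")
    if torch.cuda.is_available():
        torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Launched e.g. with: torchrun --nproc_per_node=4 train.py
if __name__ == "__main__":
    setup_distributed()
    # ... build the model, wrap it in DistributedDataParallel, run train_loop_per_worker() ...
    dist.destroy_process_group()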
Thanks
Why doesn’t Ray Train provide this, since it is such a common case in distributed training? Just curious.
Hi,
Since Ray wraps the PyTorch setup for me, how do I use dist.init_process_group() in the Ray context?
Regards
You can also use existing Torch ecosystem libraries that do this for you, e.g. torchmetrics.
See an example here.
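As a rough illustration, here is a minimal sketch of aggregating the per-worker loss with torchmetrics' MeanMetric, which syncs its state across workers on compute() as long as the torch.distributed process group is initialized; get_dataloader and compute_local_loss are hypothetical placeholders:

import torch
from torchmetrics import MeanMetric

def train_loop_per_worker():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # MeanMetric keeps a running average of everything passed to update()
    loss_metric = MeanMetric().to(device)

    for batch in get_dataloader():            # hypothetical data loader
        loss = compute_local_loss(batch)      # hypothetical per-batch loss tensor
        loss_metric.update(loss.detach())

    # compute() all-reduces the internal state, so every worker sees the same global mean
    global_loss = loss_metric.compute().item()
    loss_metric.reset()
    return global_loss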