Synchronizing workers during Ray Train

Hi everyone! We are trying to use Ray Train for distributed training. We have a use case where each worker needs to be synchronized partway through training: run one function for some time, synchronize, and then run other functions. Is there a way to achieve this?

@pratkpranav I’m not quite clear what you mean by

  • each worker to be synchronized in between training
  • run a function for some time, synchronize, then run other functions

In what context are you using “synchronize” here? Perhaps a concrete example would help?

cc: @Yard1 @amogkam

@Jules_Damji Sorry if that was not very clear. I was wondering whether there is functionality similar to MPI_Barrier, which blocks every worker until all of them have reached that call and only then lets them continue from the next line. Something like this:

def training_loop_per_worker(config):
    model = config.get("model")
    dataloader = config.get("dataloader")
    loss_fn = config.get("loss_fn")
    optimizer = config.get("optimizer")

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # <wait for all the workers to reach this line>

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

Hi @pratkpranav,

The main API to interact with Ray Train from inside your custom training loop is session.report, and this actually does serve as a synchronization barrier for workers. Training will only progress once all workers have reported.

However, session.report is mainly used to report metrics and checkpoints at the end of each epoch. If you want to use it as a barrier, you’d have to report some dummy metrics. What’s the intended use case for this?
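For reference, here's a rough sketch of what that would look like inside the training loop. The metric name "barrier_step", the per-epoch structure, and config["num_epochs"] are just placeholders for illustration, not anything Ray requires:

from ray.air import session

def training_loop_per_worker(config):
    for epoch in range(config["num_epochs"]):
        # ... first phase of work for this epoch ...

        # Acts as a barrier: each worker blocks here until every worker
        # has called session.report for this iteration.
        session.report({"barrier_step": epoch})  # dummy metric, only for synchronization

        # ... second phase of work, now that all workers are in sync ...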

Synchronizing gradients between workers is already handled by the Torch distributed backend – just make sure you call ray.train.torch.prepare_model and ray.train.torch.prepare_data_loader on your model/dataloader. See here.
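For completeness, a minimal sketch of those wrappers in a training loop (the loss function, optimizer, and the model/dataloader coming from config are placeholder choices for illustration):

import torch
import ray.train.torch

def training_loop_per_worker(config):
    # prepare_model wraps the model in DistributedDataParallel and moves it
    # to this worker's device; prepare_data_loader adds a DistributedSampler
    # and moves each batch to the device.
    model = ray.train.torch.prepare_model(config["model"])
    dataloader = ray.train.torch.prepare_data_loader(config["dataloader"])

    loss_fn = torch.nn.CrossEntropyLoss()                      # placeholder
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)   # placeholder

    model.train()
    for X, y in dataloader:
        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()   # DDP synchronizes gradients across workers here
        optimizer.step()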

Hi @justinvyu ,

Thanks for your reply. We are building a machine learning engine and creating a distributed training framework for it using Ray Train. Unfortunately, our backend currently lacks the ability to pause model training before gradient synchronization, so I'm curious whether Ray Train can solve this. Since session.report is meant for reporting metrics and checkpoints at the end of each epoch, it may not be the most suitable approach here. Are there alternative ways to accomplish this?

@pratkpranav

Ray actually does provide some basic communication primitives (including barrier). Take a look here: Ray Collective Communication Lib — Ray 2.4.0

This blog post may also be of use: Introducing Collective Communication Primitive APIs in Ray | Anyscale

I believe this is not under active development at the moment, but let me know if this suits your needs and if you run into any difficulties.
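For what it's worth, here is a minimal sketch of the barrier primitive with plain Ray actors, assuming the gloo backend is available (nccl is the other option, for GPU workers); the group name and world size are arbitrary:

import ray
import ray.util.collective as col

ray.init()

@ray.remote
class Worker:
    def run(self, world_size, rank):
        # Every worker joins the same collective group.
        col.init_collective_group(world_size, rank, backend="gloo", group_name="sync")

        # ... first phase of work ...

        # Block until all workers in the group have reached this call.
        col.barrier(group_name="sync")

        # ... second phase of work ...
        return rank

workers = [Worker.remote() for _ in range(4)]
ray.get([w.run.remote(4, rank) for rank, w in enumerate(workers)])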

You can also consider using the communication backend of torch.distributed: Distributed communication package - torch.distributed — PyTorch 2.0 documentation
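Since Ray Train's Torch backend already initializes the torch.distributed process group on each worker, something along these lines should work directly inside your training loop (a sketch; the model/dataloader/optimizer setup is omitted):

import torch.distributed as dist

def training_loop_per_worker(config):
    # model, dataloader, loss_fn, optimizer set up as before
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        # Block every worker here until all of them have reached this call.
        dist.barrier()

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()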

Thanks @justinvyu for the responses.

Thanks @justinvyu, @Jules_Damji! This is really helpful.