Hi everyone! We are trying to use Ray Train for distributed training. We have a use case where the workers need to stay synchronized during training: run a function for some time, synchronize, and then run some other functions. Is there a way to achieve this?
@pratkpranav I’m not quite clear what you mean by
- each worker to be synchronized between training
- run a function for some time, synchronize, run other functions
In what context are you using “synchronize” here? Perhaps a concrete example would help?
@Jules_Damji Sorry if it was not very clear. I was wondering whether there is a way to get functionality similar to MPI_Barrier, which blocks all the workers until every one of them has called it, and then lets them all continue running the following lines together. Something like this:
```python
def training_loop_per_worker(config):
    # dataloader, loss_fn, and optimizer are assumed to be set up elsewhere
    model = config.get("model")
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # <wait for all the workers to reach this line>

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```
Hi @pratkpranav,
The main API for interacting with Ray Train from inside your custom training loop is session.report, and this actually does serve as a synchronization barrier for workers: training will only progress once all workers have reported.
However, session.report is mainly used to report metrics and checkpoints at the end of each epoch. If you want to use it as a barrier, you'd have to report some dummy metrics. What's the intended use case for this?
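For illustration, a minimal sketch of that workaround (assuming the Ray 2.x ray.air.session API; the metric name is just a dummy, and the loop variables are from your snippet above):

```python
from ray.air import session

def training_loop_per_worker(config):
    # ... model / dataloader / optimizer setup ...
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        # Acts as a barrier: each worker blocks here until all workers have reported.
        session.report({"barrier": 0.0})

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```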
Synchronizing gradients between workers is already handled by the Torch distributed backend – just make sure you call ray.train.torch.prepare_model and ray.train.torch.prepare_data_loader on your model/dataloader. See here.
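Roughly like this (a sketch only; how the dataloader, loss_fn, and optimizer reach the loop is up to your setup, and passing the dataloader via config here is just for illustration):

```python
import ray.train.torch

def training_loop_per_worker(config):
    # Wraps the model in DistributedDataParallel and moves it to the worker's device.
    model = ray.train.torch.prepare_model(config.get("model"))
    # Adds a DistributedSampler and moves each batch to the worker's device.
    dataloader = ray.train.torch.prepare_data_loader(config.get("dataloader"))

    model.train()
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()  # DDP all-reduces the gradients across workers here
        optimizer.step()
```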
Hi @justinvyu ,
Thanks for your reply. We are currently working on a machine-learning engine and building a distributed training framework for it with Ray Train. Unfortunately, our backend currently lacks the ability to pause model training before gradient synchronization, so I'm curious whether Ray Train can help here. Although session.report reports metrics and checkpoints at the end of each epoch, it may not be the most suitable approach in this case. I'm wondering if there are alternative methods available to accomplish our goal.
Ray actually does provide some basic communication primitives (including barrier). Take a look here: Ray Collective Communication Lib — Ray 2.4.0
This blog post may also be of use: Introducing Collective Communication Primitive APIs in Ray | Anyscale
I believe this is not under active development at the moment, but let me know if this suits your needs and if you run into any difficulties.
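As a rough sketch of what that could look like inside the training loop (please double-check the exact signatures against the linked docs; the group name and the gloo backend here are just example choices):

```python
import ray.util.collective as col
from ray.air import session

def training_loop_per_worker(config):
    # One-time setup: put all training workers into a collective group.
    col.init_collective_group(
        world_size=session.get_world_size(),
        rank=session.get_world_rank(),
        backend="gloo",        # "nccl" would be the choice for GPU tensors
        group_name="sync",
    )

    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        # Blocks until every worker in the group reaches this call.
        col.barrier(group_name="sync")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```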
You can also consider using the communication backend of torch.distributed: Distributed communication package - torch.distributed — PyTorch 2.0 documentation
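Since Ray Train's TorchTrainer already initializes the torch.distributed process group on each worker, a barrier there can be as simple as this (a sketch, reusing the loop variables from your original snippet):

```python
import torch.distributed as dist

def training_loop_per_worker(config):
    # ... model / dataloader / optimizer setup ...
    for batch, (X, y) in enumerate(dataloader):
        pred = model(X)
        loss = loss_fn(pred, y)

        # Blocks until all workers in the default process group reach this point.
        dist.barrier()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```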
Thanks @justinvyu for the responses.
Thanks @justinvyu, @Jules_Damji! This is really helpful.