ray.get() becomes very slow when I increase the number of epochs

I use ray.get(current_weights) to get the final weights of my model after all epochs have finished. The problem is that with 2 epochs, for example, ray.get(current_weights) takes 17 s, but with 100 epochs it takes 1000 s. I don't know why.

Does ray.get take longer when remote functions are called many times?

    import time
    import ray

    # workers (a list of worker actors), ps (the parameter-server actor),
    # current_weights, n_epoch and total_batch are defined elsewhere.
    for epoch in range(n_epoch):
        start_epoch = time.time()
        for b in range(total_batch):
            gradients = [worker.compute_gradients.remote(current_weights) for worker in workers]
            current_weights = ps.apply_gradients.remote(*gradients)
        if (epoch + 1) % n_epoch == 0:
            weights = ray.get(current_weights)  # This line is very slow if I increase n_epoch

I'm using a parameter server for gradient sharing.

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

In general, ray.get will take longer when it is passed more objects, because of the time needed to fetch all of those objects into the current process.

However, from your code snippet, it appears that the slowdown is simply because your script is generating more work. The .remote() calls return immediately because the functions themselves execute asynchronously, so nearly all of the execution time shows up in the ray.get call. When n_epoch increases, the script submits more tasks to the worker and ps actors. Since each apply_gradients call depends on gradients computed from the previous current_weights, the final current_weights cannot be ready until every one of these actor tasks has finished, and ray.get has to wait for all of them before returning.
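To illustrate why the time shows up in the final get, here is a minimal sketch that uses Python's standard concurrent.futures as a stand-in for Ray, so it runs without Ray installed. It is an analogy under stated assumptions, not how Ray schedules work internally: a single-thread executor plays the ps actor (which runs its method calls one at a time), submit() plays the role of .remote(), and Future.result() plays the role of ray.get(). The names STEP_SECONDS, apply_step and train, and the 0.01 s step cost, are hypothetical.

```python
from __future__ import annotations

import time
from concurrent.futures import Future, ThreadPoolExecutor

STEP_SECONDS = 0.01  # hypothetical cost of one training step


def apply_step(prev: Future) -> int:
    # With a single worker thread, tasks run in submission order, so `prev`
    # has already finished by the time this task starts.
    weights = prev.result()
    time.sleep(STEP_SECONDS)  # stand-in for compute_gradients + apply_gradients
    return weights + 1


def train(n_epoch: int, total_batch: int) -> tuple[int, float, float]:
    # One worker thread: like an actor, it executes submitted calls one at a time.
    with ThreadPoolExecutor(max_workers=1) as ps:
        current_weights = ps.submit(lambda: 0)
        t0 = time.perf_counter()
        for _ in range(n_epoch):
            for _ in range(total_batch):
                # Like .remote(): returns a future immediately, nothing waits here.
                current_weights = ps.submit(apply_step, current_weights)
        submit_s = time.perf_counter() - t0

        t1 = time.perf_counter()
        weights = current_weights.result()  # plays the role of ray.get()
        wait_s = time.perf_counter() - t1
    return weights, submit_s, wait_s


for n_epoch in (2, 10):
    weights, submit_s, wait_s = train(n_epoch, total_batch=5)
    print(f"n_epoch={n_epoch}: weights={weights}, submit {submit_s:.4f}s, wait {wait_s:.2f}s")
```

Submitting the steps takes almost no time in either run, while the wait in result() is roughly n_epoch * total_batch * STEP_SECONDS, so it grows about linearly with the number of epochs, which matches the pattern of 17 s for 2 epochs versus 1000 s for 100.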


Thank you @Stephanie_Wang for your answer. Do you have any suggestions for speeding up the training? I'm using a parameter server to share the gradients between the workers and the parameter server.

Unfortunately, it is very hard to give specific advice here, since performance debugging is a difficult problem in general.

However, let me point you to the docs for debugging and profiling Ray Core. I would also highly recommend that you look into Ray Train, a library for distributed training, instead of building directly on top of the Ray Core APIs.