Performance issue with back-propagation when using RaySGD

Hello Team,

We are currently planning to replace a distributed training approach (referred to as Simple in this thread) with RaySGD. The Simple approach is based on Best Practices: Ray with PyTorch — Ray v2.0.0.dev0, but without the PyTorch elastic launch module. After benchmarking, we found a potential performance issue with RaySGD: the back-propagation steps of each epoch do not scale well when we use multiple trainers. Here are the details:

  • When we use one trainer for the entire dataset, the per-epoch breakdown numbers are as follows (forward: accumulated execution time of the forward-propagation steps; backward: accumulated execution time of the back-propagation steps; update: accumulated execution time of updating the model parameters; the sketch after the tables illustrates how these are accumulated):
+----------+----------+----------+
|          | Simple   | RaySGD   |
+----------+----------+----------+
| forward  | 0.589895 | 0.541958 |
+----------+----------+----------+
| backward | 0.406967 | 0.379009 |
+----------+----------+----------+
| update   | 0.067592 | 0.061035 |
+----------+----------+----------+
  • When we use three trainers on three dataset partitions (i.e., data parallelism), the breakdown numbers for each trainer are:
+----------+-----------+------------+
|          | Simple    | RaySGD     |
+----------+-----------+------------+
| forward  | 0.200291  | 0.181890   |
+----------+-----------+------------+
| backward | 0.147895  | 0.34216928 |
+----------+-----------+------------+
| update   | 0.0158243 | 0.0132865  |
+----------+-----------+------------+
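
For reference, here is a minimal sketch of the kind of per-phase timing accumulation described above. It is a hypothetical, self-contained illustration rather than our actual benchmark code (the toy model, loss, optimizer, and data are placeholders), and on a GPU one would also call torch.cuda.synchronize() before each timestamp to get accurate numbers:

```python
import time
import torch
import torch.nn as nn

# Hypothetical illustration of the per-phase timing above (not the actual
# benchmark code): the toy model, loss, optimizer, and data are placeholders.
model = nn.Linear(16, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loader = [(torch.randn(32, 16), torch.randn(32, 1)) for _ in range(100)]

forward_t = backward_t = update_t = 0.0
for inputs, targets in loader:
    optimizer.zero_grad()

    t0 = time.perf_counter()
    loss = criterion(model(inputs), targets)   # forward propagation
    t1 = time.perf_counter()

    loss.backward()                            # back-propagation
    t2 = time.perf_counter()

    optimizer.step()                           # parameter update
    t3 = time.perf_counter()

    forward_t += t1 - t0
    backward_t += t2 - t1
    update_t += t3 - t2

print(f"forward {forward_t:.6f}  backward {backward_t:.6f}  update {update_t:.6f}")
```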

Based on our understanding, since the number of samples processed by each trainer is divided by three when we use three trainers, the RaySGD breakdown numbers in the second table should be roughly 1/3 of those in the first table. This holds for forward and update, but the back-propagation number is much larger than expected: roughly 0.379 / 3 ≈ 0.126 would be expected, yet we observe about 0.342. Although we tried to find the root cause by diving into the RaySGD source code, we could not find any clues.

We appreciate your help in advance.

Python: 3.8.3
PyTorch: 1.7.0
Ray: 2.0.0.dev0


Hey @haiyangshi.876, thanks for posting this!

Could you provide more details about the “Simple” implementation? Is it exactly like what’s in the guide that you linked?

Also, for both the Simple and RaySGD implementations, where in the code exactly are you measuring the time?

If you could perhaps provide the code for this experiment that would really help!


Hi @amogkam , thanks for your reply.

The “simple” implementation is very similar to the one in the guide; please refer to simple_train.py and simple_sage.py on this gist page for more details. As for the time measurement, here is the time-measuring code for the “simple” implementation, and here is the code for the RaySGD implementation.

Hopefully, these code snippets are sufficient. Looking forward to your thoughts.


We’ve figured out what happened. The RaySGD-based training was working properly; the bug was in the “simple” implementation. The model used in the “simple” implementation was not wrapped in DistributedDataParallel, so there was no gradient synchronization during the back-propagation steps, which of course made it look faster in the distributed setting.
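
For anyone who hits the same issue, here is a minimal sketch of the missing piece: wrapping the model in DistributedDataParallel so that loss.backward() also performs the gradient all-reduce across trainers. The process-group setup below (backend, address, port, and the placeholder model) is only illustrative; the actual wiring depends on how the workers are launched:

```python
import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_ddp_model(rank: int, world_size: int) -> nn.Module:
    # Illustrative process-group setup; real values come from the launcher.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = nn.Linear(16, 1)   # placeholder for the actual model
    # Without this wrapper, each trainer updates its own copy of the model
    # and backward() does no cross-trainer gradient synchronization.
    return DDP(model)
```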
