PPO is using too much GPU memory

Hey, I’m moving my discussion about PPO memory usage with @sven1977 from Slack to here so that everyone can benefit.

Me:

Hey guys, I’m currently trying to optimise PPO memory usage on GPU with torch. From what I understand, PPO uses pinned_memory, so if I have train_batch_size = 43200 it is going to pre-allocate enough GPU memory to store all 43200 values. Then it slices the 43200 values into minibatches of size sgd_minibatch_size = 1350. But if I increase sgd_minibatch_size I get a CUDA OOM error. From my reasoning, I should be able to use sgd_minibatch_size = train_batch_size. Am I not understanding this correctly? Just to clarify: I’m using 2 GPUs.
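
For reference, the relevant settings look roughly like this (only train_batch_size, sgd_minibatch_size and the 2 GPUs are from my actual setup; the rest is just a generic torch PPO config, not my exact one):

```python
# Sketch of the PPO settings in question (old-style RLlib config dict).
config = {
    "framework": "torch",
    "num_gpus": 2,                 # the two GPUs mentioned above
    "train_batch_size": 43200,     # full batch collected per training iteration
    "sgd_minibatch_size": 1350,    # size of each SGD minibatch sliced from it
    "num_sgd_iter": 30,            # RLlib's default: 30 passes over the train batch
}
# trainer = PPOTrainer(config=config, env="MyEnv-v0")  # "MyEnv-v0" is a placeholder
```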

Also, I find it weird that I’m having OOM errors, because when I compute the size of my batch, it is way less than what torch is allocating for itself.
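
Just to show the kind of back-of-the-envelope calculation I mean (the observation size below is a made-up example, not my actual env):

```python
# Rough batch-size arithmetic (assumed observation shape, float32 values).
train_batch_size = 43200
obs_dim = 100            # hypothetical: 100 floats per observation
bytes_per_float = 4      # float32

obs_bytes = train_batch_size * obs_dim * bytes_per_float
print(f"observations only: {obs_bytes / 1024**2:.1f} MiB")   # ~16.5 MiB
# Even adding actions, rewards, advantages, logits, value targets, etc.,
# the raw sample batch stays far below typical GPU capacity, which is why
# the high allocation during training surprises me.
```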

Sven:

Hey, your batch size is quite large. I don’t think we preallocate anything on the GPU (afair); we just copy sub-batches from CPU to GPU as necessary, one at a time. There is still an issue in RLlib with PPO + GPUs in the sense that we copy the data from CPU to GPU too many times. PPO does sub-sampling, by default 30 times, over the same batch, which leads to unnecessarily many copy steps here. Our TFMultiGPU optimizer handles this correctly by copying the entire batch just once and then doing the sub-sampling directly on the GPU.
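
To illustrate the two copy patterns being described (a simplified sketch of the idea only, not RLlib’s actual code; ppo_loss and optimizer are placeholders):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
full_batch = torch.randn(43200, 100)       # stand-in for the collected CPU sample batch
minibatch_size, num_sgd_iter = 1350, 30

# Pattern A: re-copy each minibatch from CPU to GPU on every SGD iteration
# (32 minibatches x 30 iterations = 960 small host-to-device transfers).
for _ in range(num_sgd_iter):
    for start in range(0, full_batch.shape[0], minibatch_size):
        mb = full_batch[start:start + minibatch_size].to(device)
        # loss = ppo_loss(mb); loss.backward(); optimizer.step()  # placeholder

# Pattern B (what the TF multi-GPU optimizer does): copy the whole batch once,
# then slice minibatches directly on the device -- one big transfer, but the
# full batch then occupies GPU memory for the whole optimization phase.
full_batch_gpu = full_batch.to(device)
for _ in range(num_sgd_iter):
    for start in range(0, full_batch_gpu.shape[0], minibatch_size):
        mb = full_batch_gpu[start:start + minibatch_size]
        # same loss/optimizer steps, but no further CPU->GPU copies
```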

Me:

Hello, thanks for your answer! What do you mean by “PPO does sub-sampling, by default 30 times, over the same batch” exactly? From what I understand about the implementation of PPO in RLlib, (train_batch_size / sgd_minibatch_size) * num_sgd_iter gives the number of optimization iterations. Where is the subsampling happening exactly? Is there a way for me to code a quick fix and reduce GPU memory on top of the existing code, in your opinion?
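
For concreteness, plugging my numbers into that formula:

```python
train_batch_size = 43200
sgd_minibatch_size = 1350
num_sgd_iter = 30   # RLlib's default for PPO

minibatches_per_pass = train_batch_size // sgd_minibatch_size   # 32
total_sgd_steps = minibatches_per_pass * num_sgd_iter            # 960
print(minibatches_per_pass, total_sgd_steps)
```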

RLlib will pass dummy tensors through your model during the first pass, preallocating a ton of GPU memory. These tend to be larger than my training batches. If your OOM is happening on initialisation, I suggest you make sure it is not the large dummy tensors causing the issue.
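
If it helps, one simple way to check (a minimal sketch, to be run right after the trainer/policy is built and before the first training step):

```python
import torch

# Right after building the trainer, before any train() call:
for i in range(torch.cuda.device_count()):
    allocated = torch.cuda.memory_allocated(i) / 1024**2
    peak = torch.cuda.max_memory_allocated(i) / 1024**2
    print(f"GPU {i}: allocated={allocated:.0f} MiB, peak so far={peak:.0f} MiB")
```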

Hey @smorad, thanks for the answer,

the OOM is happening when the backward passes start, not during initialization with dummy tensors. To be more precise, here is what happens chronologically:

  1. init networks using dummy tensors and stuff
  2. collect trajectories
  3. optimize PPO losses
  4. go to 2.

My memory usage is low during 1. and 2. but very high during 3. Or maybe I misunderstood what you meant, and PPO is using dummy tensors at the start of 3.?
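
For what it’s worth, this is roughly how I’m confirming where the peak happens (a minimal sketch; trainer stands for my PPOTrainer instance):

```python
import torch

torch.cuda.reset_peak_memory_stats()               # reset before the next iteration
result = trainer.train()                           # one collect + optimize cycle
peak = torch.cuda.max_memory_allocated() / 1024**2
print(f"peak GPU memory this iteration: {peak:.0f} MiB")
```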

Hey @floepfl, there is a PR in review right now fixing some of the torch GPU-related problems, particularly for PPO, by doing proper batch pre-loading (instead of repeatedly re-loading the same data for different SGD iterations).

It also makes PPO 33% faster on 1 GPU and >60% faster when using 2 GPUs (compared with the previous RLlib releases). Please let us know whether this fixes some of your OOM issues or not.