- Severity: High (it blocks me from completing my task).
Hello,
The purpose of my task is to increase `train_batch_size`, but I am running into GPU memory limitations. Ideally, I would like all samples to stay in a CPU buffer and have only `sgd_minibatch_size` samples loaded onto the GPU on the fly, so that `train_batch_size` can be increased independently of GPU memory. I am unsure how to configure this (if such an option exists).
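For reference, here is a stripped-down sketch of how I launch training. It is simplified: the real setup uses a custom `TempestPPOTrainer` subclass of `PPOTrainer` and our own environment, so the env name below is only a placeholder.

```python
import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer  # our real trainer subclasses this

config = {
    "env": "CartPole-v1",               # placeholder; our real env has ~1000-value observations
    "framework": "torch",
    "num_gpus": 1,                      # a single 24 GB RTX 6000
    "rollout_fragment_length": 1000,
    "batch_mode": "complete_episodes",
    "num_sgd_iter": 4,
    "sgd_minibatch_size": 20480,        # what I expected to bound GPU memory
    "train_batch_size": 1024000,        # what I actually want to increase
}

ray.init()
tune.run(PPOTrainer, config=config)
```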
Sorry for the brief title, but there are two issues I have noticed, with no documented way of working around them:
- RLlib seems to allocate GPU memory for training lazily. We are training on 24 GB RTX 6000s and cannot use the full memory capacity because of these additional allocations: the first iteration allocates 17 GB, and the second then allocates 8 GB more via `load_batch_into_buffer`.
- `train_batch_size` seems to be the driver of GPU memory allocation, which is counter-intuitive, because `sgd_minibatch_size` is what the documentation describes as driving GPU memory usage. My rough mental model of what seems to be happening is sketched just below.
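From the traceback at the bottom of this post, my current understanding (and I may well be reading it wrong) is that the multi-GPU train op copies the entire train batch onto the device before slicing it into SGD minibatches, roughly along these lines:

```python
import numpy as np
import torch

device = torch.device("cuda:0")

def load_whole_train_batch(batch: dict) -> dict:
    # Rough mental model of what load_batch_into_buffer appears to do
    # (simplified, NOT RLlib's actual implementation): the full train batch
    # is moved to the GPU up front, so GPU memory scales with
    # train_batch_size rather than sgd_minibatch_size.
    return {key: torch.from_numpy(arr).to(device) for key, arr in batch.items()}

# What I was hoping to configure instead: keep the full train batch on the CPU
# and move only one sgd_minibatch_size slice to the GPU per SGD step.
def load_one_minibatch(batch: dict, start: int, minibatch_size: int) -> dict:
    return {
        key: torch.from_numpy(arr[start:start + minibatch_size]).to(device)
        for key, arr in batch.items()
    }
```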
Our observation size is roughly 1000 inputs per sample. Here is the approximate GPU memory allocated on the first training iteration for three configurations:
~9 GB on the first iteration:

```yaml
rollout_fragment_length: 1000
sgd_minibatch_size: 20480
num_sgd_iter: 4
train_batch_size: 512000
batch_mode: complete_episodes
```
~17 GB on the first iteration:

```yaml
rollout_fragment_length: 1000
sgd_minibatch_size: 2048
num_sgd_iter: 4
train_batch_size: 1024000
batch_mode: complete_episodes
```
~17 GB on the first iteration:

```yaml
rollout_fragment_length: 1000
sgd_minibatch_size: 20480
num_sgd_iter: 4
train_batch_size: 1024000
batch_mode: complete_episodes
```
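As a rough sanity check on those numbers (assuming float32 observations with ~1000 values each; actions, logits, value targets, etc. come on top of this), the observation portion of the train batch alone scales directly with `train_batch_size`:

```python
# Back-of-the-envelope estimate of just the observation part of the train batch,
# assuming float32 and ~1000 values per observation (our actual obs size).
obs_values = 1000
bytes_per_value = 4  # float32

for train_batch_size in (512000, 1024000):
    gib = train_batch_size * obs_values * bytes_per_value / 1024**3
    print(f"train_batch_size={train_batch_size}: obs alone ~{gib:.1f} GiB")

# train_batch_size=512000:  obs alone ~1.9 GiB
# train_batch_size=1024000: obs alone ~3.8 GiB
```

The measured footprint roughly doubles when `train_batch_size` doubles (9 GB vs. 17 GB) and does not change when `sgd_minibatch_size` changes by 10x, which is what makes me think the whole train batch ends up on the GPU.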
The training run then fails with a CUDA OOM:

```
Failure # 1 (occurred at 2022-07-21_05-26-01)
Traceback (most recent call last):
File "/opt/conda/lib/python3.8/site-packages/ray/tune/trial_runner.py", line 893, in _process_trial
results = self.trial_executor.fetch_result(trial)
File "/opt/conda/lib/python3.8/site-packages/ray/tune/ray_trial_executor.py", line 707, in fetch_result
result = ray.get(trial_future[0], timeout=DEFAULT_GET_TIMEOUT)
File "/opt/conda/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
return func(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/ray/worker.py", line 1733, in get
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::TempestPPOTrainer.train() (pid=249, ip=10.42.143.37, repr=TempestPPOTrainer)
File "/opt/conda/lib/python3.8/site-packages/ray/tune/trainable.py", line 315, in train
result = self.step()
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 982, in step
raise e
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 963, in step
step_attempt_results = self.step_attempt()
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 1042, in step_attempt
step_results = self._exec_plan_or_training_iteration_fn()
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/agents/trainer.py", line 1962, in _exec_plan_or_training_iteration_fn
results = next(self.train_exec_impl)
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 756, in __next__
return next(self.built_iterator)
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 783, in apply_foreach
for item in it:
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 783, in apply_foreach
for item in it:
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 843, in apply_filter
for item in it:
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 843, in apply_filter
for item in it:
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 783, in apply_foreach
for item in it:
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 783, in apply_foreach
for item in it:
File "/opt/conda/lib/python3.8/site-packages/ray/util/iter.py", line 791, in apply_foreach
result = fn(item)
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/execution/train_ops.py", line 310, in __call__
num_loaded_samples[policy_id] = self.local_worker.policy_map[
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 490, in load_batch_into_buffer
slices = [
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 491, in <listcomp>
slice.to_device(self.devices[i]) for i, slice in enumerate(slices)
File "/opt/conda/lib/python3.8/site-packages/ray/rllib/policy/sample_batch.py", line 661, in to_device
self[k] = torch.from_numpy(v).to(device)
RuntimeError: CUDA out of memory. Tried to allocate 8.48 GiB (GPU 0; 23.65 GiB total capacity; 17.30 GiB already allocated; 5.25 GiB free; 17.42 GiB reserved in total by PyTorch)
```