Let's say I am using IMPALA with several workers and I have 4 GPUs.
I want to be able to say that GPUs 0, 1, 2 should be shared evenly across the workers, so inference/rendering is accelerated by them, but I want GPU 3 to be dedicated to the training loop.
This should probably work out-of-the-box in the latest master.
I would probably set num_gpus_per_worker=[3/num_workers] and our new PlacementGroups support should take care of the rest.
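In config terms, that would look roughly like the sketch below (not a tested recipe; num_workers=6 and the Breakout env are placeholders, and the exact num_gpus_per_worker fraction depends on how many workers you run):

import ray
from ray import tune

num_workers = 6  # placeholder

ray.init()
tune.run(
    "IMPALA",
    config={
        "env": "BreakoutNoFrameskip-v4",  # placeholder env
        "framework": "torch",
        "num_workers": num_workers,
        # One GPU reserved for the learner (training loop).
        "num_gpus": 1,
        # Share the remaining 3 GPUs evenly across the rollout workers.
        "num_gpus_per_worker": 3 / num_workers,
    },
)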
If you need more details on how we do this now in RLlib (and IMPALA), you can look at IMPALA's DefaultResourceRequest override (in rllib/agents/impala.py::OverrideDefaultResourceRequest).
In that method, we define the placement group bundles depending on how many workers we have. The first bundle is always the learner one (the "driver"), then come the worker bundles (which we will attempt to PACK, i.e. place on the same node if possible), then the eval workers (you may not have any).
On the latest master this just fails with an error out-of-the-box:
(pid=13371) 2021-04-18 10:16:52,677 ERROR worker.py:395 -- Exception raised in creation task: The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=13371, ip=192.168.1.105)
(pid=13371) File "python/ray/_raylet.pyx", line 505, in ray._raylet.execute_task
(pid=13371) File "python/ray/_raylet.pyx", line 449, in ray._raylet.execute_task.function_executor
(pid=13371) File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/ray/_private/function_manager.py", line 566, in actor_method_executor
(pid=13371) return method(__ray_actor, *args, **kwargs)
(pid=13371) File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 516, in __init__
(pid=13371) self.policy_map, self.preprocessors = self._build_policy_map(
(pid=13371) File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1158, in _build_policy_map
(pid=13371) policy_map[name] = cls(obs_space, act_space, merged_conf)
(pid=13371) File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/ray/rllib/policy/policy_template.py", line 243, in __init__
(pid=13371) self.parent_cls.__init__(
(pid=13371) File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 145, in __init__
(pid=13371) self.model_gpu_towers = nn.parallel.replicate.replicate(
(pid=13371) AttributeError: 'function' object has no attribute 'replicate'
If I fix the replicate.replicate() line, I run into another error:
File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/ray/rllib/policy/torch_policy.py", line 145, in __init__
self.model_gpu_towers = nn.parallel.replicate(
File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/torch/nn/parallel/replicate.py", line 91, in replicate
param_copies = _broadcast_coalesced_reshape(params, devices, detach)
File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/torch/nn/parallel/replicate.py", line 71, in _broadcast_coalesced_reshape
tensor_copies = Broadcast.apply(devices, *tensors)
File "/home/bam4d/anaconda3/envs/griddly/lib/python3.8/site-packages/torch/nn/parallel/_functions.py", line 14, in forward
assert all(i.device.type != 'cpu' for i in inputs), (
AssertionError: Broadcast function not implemented for CPU tensors
The parallel.replicate call does not work because the model is not on the GPU when it is replicated, so it runs into the error above. In addition to this, if I fix it by moving the model with model.to(device), I am creating at least two models on the GPU, which causes the GPU to run out of memory.
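A minimal repro of that assertion outside RLlib (a sketch only; the tiny model and the device list are placeholders):

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # still resides on the CPU

# Replicating a CPU-resident model onto GPU devices trips the
# "Broadcast function not implemented for CPU tensors" assertion:
# nn.parallel.replicate(model, [0, 1])  # raises AssertionError

# Moving the model to a GPU first avoids the assertion, but then the
# original on cuda:0 plus its replica means two copies in GPU memory:
if torch.cuda.device_count() >= 2:
    model = model.to("cuda:0")
    towers = nn.parallel.replicate(model, [0, 1])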
Removing the multi-GPU code and ensuring a single model is placed on the GPU solves my issues (obviously not a proper fix):
# # Create multi-GPU model towers, if necessary.
# # - The central main model will be stored under self.model, residing on
# #   self.device.
# # - Each GPU will have a copy of that model under
# #   self.model_gpu_towers, matching the devices in self.devices.
# # - Parallelization is done by splitting the train batch and passing
# #   it through the model copies in parallel, then averaging over the
# #   resulting gradients, applying these averages on the main model and
# #   updating all towers' weights from the main model.
# # - In case of just one device (1 (fake) GPU or 1 CPU), no
# #   parallelization will be done.
# if config["_fake_gpus"] or config["num_gpus"] == 0 or \
#         not torch.cuda.is_available():
#     logger.info(
#         "TorchPolicy running on {}.".format("{} fake-GPUs".format(
#             config["num_gpus"]) if config["_fake_gpus"] else "CPU"))
#     self.device = torch.device("cpu")
#     self.model = model.to(self.device)
#     self.devices = [
#         self.device for _ in range(config["num_gpus"] or 1)
#     ]
#     self.model_gpu_towers = [
#         model if config["num_gpus"] == 0 else copy.deepcopy(model)
#         for i in range(config["num_gpus"] or 1)
#     ]
# else:
#     logger.info("TorchPolicy running on {} GPU(s).".format(
#         config["num_gpus"]))
#     self.device = torch.device("cuda")
#     self.model = model.to(self.device)
#     print(ray.get_gpu_ids())
#     self.devices = [
#         torch.device("cuda:{}".format(id_))
#         for i, id_ in enumerate(ray.get_gpu_ids())
#         if i < config["num_gpus"]
#     ]
#     self.model_gpu_towers = nn.parallel.replicate(
#         model, [
#             id_ for i, id_ in enumerate(ray.get_gpu_ids())
#             if i < config["num_gpus"]
#         ])
# print(f'{torch.cuda.device_count()}')
self.device = torch.device('cuda:0')
self.devices = [self.device]
self.model = model.to(self.device)
self.model_gpu_towers = [self.model]
So for me, the model.to(self.device) adds one model to the GPU and then the nn.parallel.replicate adds another. So if I only have 1 GPU per worker, it actually loads two models.
Yeah, I see, it should not do any multi-GPU copying/tower generation if num_gpus <= 1. Makes sense. I'll fix this.
Even for num_gpus=2, we could probably nix the "main" model and just use one of the towers.
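For reference, a rough sketch of what skipping the tower replication in the single-GPU case could look like (this is not the actual PR; the helper name, the standalone signature, and the config handling are simplifications for illustration):

import copy
import torch
import torch.nn as nn

def build_model_towers(model: nn.Module, num_gpus: int, fake_gpus: bool):
    """Sketch: place the model and create per-device towers only if needed."""
    if fake_gpus or num_gpus == 0 or not torch.cuda.is_available():
        # CPU (or fake-GPU) case: no real GPU replication needed.
        device = torch.device("cpu")
        model = model.to(device)
        devices = [device for _ in range(num_gpus or 1)]
        towers = [model if num_gpus == 0 else copy.deepcopy(model)
                  for _ in range(num_gpus or 1)]
    elif num_gpus == 1:
        # Single GPU: reuse the main model as the only "tower" so we
        # do not hold two copies in GPU memory.
        device = torch.device("cuda:0")
        model = model.to(device)
        devices = [device]
        towers = [model]
    else:
        # Multi-GPU: move the main model to the first device, then
        # replicate onto all local devices.
        device = torch.device("cuda:0")
        model = model.to(device)
        devices = [torch.device("cuda:{}".format(i)) for i in range(num_gpus)]
        towers = nn.parallel.replicate(model, list(range(num_gpus)))
    return model, device, devices, towers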
Hey @Bam4d, here is a preliminary PR with major fixes for torch multi-GPU. Unfortunately, I'm still not able to get num_gpus_per_worker>0 to work. A strange torch error comes up: "Invalid device ordinal".
Any ideas? I'm running IMPALA on 4 GPUs against Atari: rllib/tuned_examples/impala/atari-impala-multi-gpu.yaml
GPU utilization also looks good now:
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.05 Driver Version: 450.51.05 CUDA Version: 11.0 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla K80 On | 00000000:00:17.0 Off | 0 |
| N/A 72C P0 114W / 149W | 2702MiB / 11441MiB | 70% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 1 Tesla K80 On | 00000000:00:18.0 Off | 0 |
| N/A 56C P0 123W / 149W | 2612MiB / 11441MiB | 53% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 2 Tesla K80 On | 00000000:00:19.0 Off | 0 |
| N/A 72C P0 119W / 149W | 2612MiB / 11441MiB | 100% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
| 3 Tesla K80 On | 00000000:00:1A.0 Off | 0 |
| N/A 56C P0 126W / 149W | 2612MiB / 11441MiB | 98% Default |
| | | N/A |
+-------------------------------+----------------------+----------------------+
So the invalid device ordinal might be because the ray_gpu_id is not consistent with the devices visible to CUDA. I came across this problem before with Griddly: if there is only 1 device visible to CUDA, it will always re-index that device to cuda:0.
So if the ray_gpu_id is 1 and CUDA_VISIBLE_DEVICES=1, then cuda:0 will point to the single GPU and cuda:1 will throw the device ordinal error.
I'm guessing you are passing CUDA_VISIBLE_DEVICES to the worker processes? This might be the cause.
I think this line here is the problem: you are using the ray_gpu_id, but CUDA will index the visible devices from 0. I think if you change 'id_' to 'i' it will work for this case; not sure about other cases, it might be a bit more complicated.
Cool, yeah, that could be because torch always starts enumerating from 0. For the learner, this is fine (it gets the first n GPU IDs), but for the workers, this would then be a problem. Let me try to fix this …
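To illustrate the id mapping, a small sketch with a hypothetical helper (the key point being that inside a worker whose visible devices Ray has restricted, torch re-enumerates the GPUs from 0):

import ray
import torch

def worker_torch_devices(num_gpus: int):
    # ray.get_gpu_ids() returns the *global* GPU ids assigned to this
    # worker (e.g. [1]), but torch only sees the visible devices,
    # re-indexed from 0 -- so use the local position i, not id_.
    gpu_ids = ray.get_gpu_ids()
    return [
        torch.device("cuda:{}".format(i))
        for i, id_ in enumerate(gpu_ids)
        if i < num_gpus
    ]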
Thanks so much for your help with this btw.
I did confirm yesterday that e.g. IMPALA runs (and learns) 4x as fast with 4 GPUs as opposed to with 1 GPU with this PR already. Just have to figure out the worker GPUs now …
Awesome, that was it! @Bam4d, could you do me a favor and try it now on the above PR? I just tested with the following config and I do see a 4x speedup (throughput and learning) for 4 GPUs over a single GPU.
# Runs on a p2.8xlarge single head node machine.
# Should reach ~400 reward in about 1h and after 15-20M ts.
atari-impala:
    env: BreakoutNoFrameskip-v4
    run: IMPALA
    config:
        framework: torch
        rollout_fragment_length: 50
        train_batch_size: 4000
        num_gpus: 4
        num_workers: 2  # 31 on a p2.8xlarge machine to get above results
        num_gpus_per_worker: 0.5  # or 0 to get above results
        num_envs_per_worker: 5
        clip_rewards: True
        lr_schedule: [
            [0, 0.0005],
            [20000000, 0.000000000001],
        ]