SAC on multi-GPU with PyTorch

I tried to train a SAC agent on multiple GPUs because the documentation here states that multi-GPU is supported for SAC with the PyTorch framework. I started with the following sample code:

import ray
from ray import tune
from ray.rllib.agents.sac import SACTrainer
from ray.rllib.examples.env.random_env import RandomEnv

if __name__ == '__main__':
    ray.init(num_gpus=2)
    config = {
        'framework': 'torch',
        'num_workers': 0,
        'num_envs_per_worker': 1,
        'num_gpus': 2,
        'env': RandomEnv,
    }

    stop = {
        'timesteps_total': 10000,
    }

    result = tune.run(SACTrainer, config=config, stop=stop)

Then I got the error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument tensors in method wrapper__cat)
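
For context, this is the kind of mismatch `torch.cat` raises whenever the tensors it receives live on different devices. A minimal sketch in plain PyTorch (not RLlib code) that reproduces it and the `.to("cpu")` workaround I ended up using:

    import torch

    if torch.cuda.is_available():
        a = torch.tensor([1.0], device="cuda:0")  # e.g. a tower's td_error on the GPU
        b = torch.tensor([2.0])                   # a default tensor stays on the CPU

        try:
            torch.cat([a, b], dim=0)              # mixes cuda:0 and cpu -> error
        except RuntimeError as e:
            print(e)

        # moving everything onto one device first makes the cat work
        print(torch.cat([a.to("cpu"), b], dim=0))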

I fixed it by changing the `td_error = torch.cat(...)` over `policy.model_gpu_towers` so that each tower's `td_error` is moved to the CPU before concatenating:

    td_error = torch.cat(
        [
            getattr(t, "td_error", torch.tensor([0.0])).to("cpu")
            for t in policy.model_gpu_towers
        ],
        dim=0)

But then I got another error:
ValueError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument mat1 in method wrapper_addmm)
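
This second error is the same kind of mismatch, just between two GPUs: a module whose weights sit on cuda:1 gets fed a batch that sits on cuda:0. A minimal sketch in plain PyTorch (assuming two visible GPUs) of the failure and of the `next(model.parameters()).device` pattern I used in the changes below:

    import torch
    import torch.nn as nn

    if torch.cuda.device_count() >= 2:
        model = nn.Linear(4, 2).to("cuda:1")        # e.g. the second GPU tower
        batch = torch.randn(3, 4, device="cuda:0")  # batch prepared on the first GPU

        try:
            model(batch)                            # addmm sees cuda:1 weights, cuda:0 input
        except RuntimeError as e:
            print(e)

        # ask the model where its parameters actually are and move the input there
        device = next(model.parameters()).device
        print(model(batch.to(device)).shape)        # torch.Size([3, 2])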

I then looked further into the code. After debugging for a while, I figured out it’s mostly because of wrong assumptions about where the tensors are: the assumptions hold when there is only one GPU, but not when there are multiple. The changes I made are as follows:
from:
https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac_torch_policy.py#L182-L185 (I am restricted from putting in more than 2 links as a new user, sorry for the inconvenience)
into:

    target_model_out_tp1, _ = policy.target_model({
        "obs": train_batch[SampleBatch.NEXT_OBS],
        "is_training": True,
    }, [], None)
    target_model_out_tp1 = target_model_out_tp1.to(policy.device)

from:
https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac_torch_policy.py#L205
into:

    model_device = next(model.parameters()).device
    q_tp1 = q_tp1.to(model_device)
    q_tp1 -= alpha * log_pis_tp1

from:
https://github.com/ray-project/ray/blob/master/rllib/agents/sac/sac_torch_policy.py#L495-L499
into:

    policy.target_model = policy.target_model.to(policy.device)
    policy.model.log_alpha = policy.model.log_alpha.to(policy.device)
    policy.model.target_entropy = policy.model.target_entropy.to(policy.device)
    for model in policy.model_gpu_towers:
        device = next(model.parameters()).device
        model.log_alpha = model.log_alpha.to(device)
        model.target_entropy = model.target_entropy.to(device)
    ComputeTDErrorMixin.__init__(policy)
    TargetNetworkMixin.__init__(policy)

After these changes, I am able to finish the trial with the sample code for discrete actions, but I am not sure whether my changes really fix the problems correctly, or whether there will be other problems.
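
If it helps with reviewing, this is the rough sanity check I would run after the changes (a hypothetical helper, not part of RLlib) to confirm that every tower keeps its parameters, `log_alpha`, and `target_entropy` on the same device:

    def check_tower_devices(policy):
        # hypothetical helper: assert that each GPU tower's extra tensors
        # live on the same device as that tower's parameters
        for i, tower in enumerate(policy.model_gpu_towers):
            device = next(tower.parameters()).device
            assert tower.log_alpha.device == device, (i, tower.log_alpha.device)
            assert tower.target_entropy.device == device, (i, tower.target_entropy.device)
        print("all towers consistent")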

Here’s some configuration info of my setup:
OS: Ubuntu 16.04
Python: 3.8
PyTorch: 1.9.0
Ray: 1.4.1