SAC algorithm fails when used with PyTorch for discrete actions

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

When running the SAC algorithm on the CartPole env with a grid search over the tf2 and torch frameworks, following the tuned example here, torch seems to fail while tf2 seems to work. Both torch and tf2 use the same hyperparameters, so there seems to be a bug in the torch implementation. The training graph is shown below, where torch is the grey line and tf2 is the blue line. On the Pendulum env the two frameworks perform similarly, so I suspect the bug is specific to discrete environments.
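For context on why alpha matters so much in the discrete case: in discrete SAC the actor loss sums over all actions, weighting each term by the action's probability, and alpha sets how much entropy is worth relative to the Q values. Below is a plain-Python sketch of that loss for a single state, assuming the standard discrete-SAC formulation (RLlib's torch code may differ in details such as batching and detaching). The helpers `softmax` and `discrete_actor_loss` are hypothetical, not RLlib API.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of action logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def discrete_actor_loss(logits, q_values, alpha):
    """Expected actor loss for one state: sum_a pi(a) * (alpha * log pi(a) - Q(a))."""
    probs = softmax(logits)
    return sum(p * (alpha * math.log(p) - q) for p, q in zip(probs, q_values))


q_values = [1.0, 0.0]        # action 0 is the better action
greedy_logits = [5.0, 0.0]   # policy strongly prefers action 0
uniform_logits = [0.0, 0.0]  # policy picks both actions equally often

for alpha in (0.01, 10.0):
    greedy = discrete_actor_loss(greedy_logits, q_values, alpha)
    uniform = discrete_actor_loss(uniform_logits, q_values, alpha)
    winner = "greedy" if greedy < uniform else "uniform"
    print(f"alpha={alpha}: {winner} policy has the lower loss")
```

In this toy setting a small alpha favors the greedy policy and a large alpha favors the uniform one, so an alpha that drifts upward can keep the agent close to random.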

Versions / Dependencies

Ray 2.2.0
Gym 0.23.1
Python 3.9
Ubuntu

Reproduction script

# Set the visible GPUs.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

import ray
from ray import air, tune
from ray.rllib.algorithms.sac.sac import SACConfig

rllib_config = (
    SACConfig()
    .framework(framework=tune.grid_search(["torch", "tf2"]), eager_tracing=True)
    .environment(env="CartPole-v1", render_env=False)
    .resources(num_gpus=1)
    .debugging(seed=0)
    .training(
        gamma=0.95,
        target_network_update_freq=32,
        tau=1.0,
        train_batch_size=32,
        optimization_config={
            "actor_learning_rate": 0.005,
            "critic_learning_rate": 0.005,
            "entropy_learning_rate": 0.0001,
        },
    )
)

air_config = air.RunConfig(
    name='SAC',
    stop={'timesteps_total': 10000000},
    checkpoint_config=air.CheckpointConfig(
        checkpoint_at_end=True
    ),
    local_dir='~/Projects/rl-diff-test/results',
    log_to_file=True,
)

tuner = tune.Tuner(
    'SAC',
    param_space=rllib_config,
    run_config=air_config,
)

ray.init()
tuner.fit()
ray.shutdown()

See the GitHub issue here for more information.

Hi @rajfly & @avnishn,

You guys are going to like this one. I was able to reproduce the same curves as @rajfly with his reproduction script (nice job providing one). Here is where it gets interesting: I was also able to reproduce it, or at least something similar, with TensorFlow, depending on whether I use framework: tf (pink) or tf2 (green). These results are from master as of Friday at 18:00 EST.

Take a look:

What is the difference? Well, it appears that tf and tf2 treat model.alpha differently. In tf it changes as a function of log_alpha, whereas in tf2 it is a static tensor that never changes.
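A minimal plain-Python sketch of the difference described above, assuming (as the thread reports) that in tf2 eager mode model.alpha is evaluated once when the model is built, while tf graph mode re-evaluates it from log_alpha on every run. `TemperatureModel`, `alpha_snapshot` and `alpha_live` are hypothetical names used only for illustration, not RLlib API.

```python
import math


class TemperatureModel:
    """Hypothetical stand-in for a SAC model that owns a trainable log_alpha."""

    def __init__(self, initial_alpha: float) -> None:
        self.log_alpha = math.log(initial_alpha)
        # tf2 (eager) behavior as reported: alpha is evaluated once, at construction.
        self.alpha_snapshot = math.exp(self.log_alpha)

    def alpha_live(self) -> float:
        # tf (graph) behavior: alpha is recomputed from the current log_alpha.
        return math.exp(self.log_alpha)


model = TemperatureModel(initial_alpha=1.0)
model.log_alpha += 0.5  # pretend the entropy optimizer took a step

print(model.alpha_snapshot)  # still 1.0
print(model.alpha_live())    # exp(0.5), about 1.6487
```

The snapshot never sees the optimizer's update to log_alpha, so any loss that reads it effectively trains with a fixed temperature.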

Take a look:

OK, now that we have a theory, let's go see what torch is doing.

It explicitly computes alpha = torch.exp(model.log_alpha) inside the loss function, so it should behave like tf.
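Because the torch loss recomputes alpha from log_alpha on every call, alpha follows whatever the temperature (entropy) loss does to log_alpha. Below is a plain-Python sketch of that drift under several assumptions: the standard SAC temperature objective in its discrete, probability-weighted form; plain gradient descent instead of the real optimizer; a policy held fixed; and a target entropy of 0.98 * log(n_actions), which is assumed here rather than taken from the thread. The learning rate mirrors the entropy_learning_rate of 0.0001 from the reproduction script. All function names are hypothetical.

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of a discrete action distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def alpha_loss_grad(probs, target_entropy):
    """Gradient w.r.t. log_alpha of sum_a p_a * (-log_alpha * (log p_a + target_entropy)).

    Because the probabilities sum to 1, the loss simplifies to
    -log_alpha * (target_entropy - H(pi)), whose gradient is H(pi) - target_entropy.
    """
    return entropy(probs) - target_entropy


def train_log_alpha(probs, target_entropy, lr, steps, log_alpha=0.0):
    """Plain gradient descent on log_alpha with the policy held fixed."""
    for _ in range(steps):
        log_alpha -= lr * alpha_loss_grad(probs, target_entropy)
    return log_alpha


n_actions = 2                        # CartPole has two actions
target = 0.98 * math.log(n_actions)  # assumed discrete target entropy
confident = train_log_alpha([0.9, 0.1], target, lr=1e-4, steps=10_000)
uniform = train_log_alpha([0.5, 0.5], target, lr=1e-4, steps=10_000)

print(math.exp(confident))  # alpha grows above its initial value of 1.0
print(math.exp(uniform))    # alpha shrinks slightly below 1.0
```

With a target this close to the maximum entropy log(2), any policy that clearly prefers one action pushes alpha up in this toy model, whereas a static alpha would not move at all.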

Alright, that matches what we predicted. That gives us a second prediction: if we make alpha in torch behave like it does in tf2, torch should also solve CartPole. Let's try.

Yup, that did it.

So now I am wondering: are any other algorithms in RLlib affected by this issue? I am not sure, but I will leave that investigation to the paid professionals. =)


diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py
index b26bc26e0..483b9ede3 100644
--- a/rllib/algorithms/sac/sac.py
+++ b/rllib/algorithms/sac/sac.py
@@ -105,6 +105,8 @@ class SACConfig(AlgorithmConfig):
         self.use_state_preprocessor = DEPRECATED_VALUE
         self.worker_side_prioritization = DEPRECATED_VALUE
 
+        self.model_alpha = True
+
     @override(AlgorithmConfig)
     def training(
         self,
@@ -126,6 +128,7 @@ class SACConfig(AlgorithmConfig):
         _deterministic_loss: Optional[bool] = NotProvided,
         _use_beta_distribution: Optional[bool] = NotProvided,
         num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
+        model_alpha: Optional[bool] = NotProvided,
         **kwargs,
     ) -> "SACConfig":
         """Sets the training related configuration.
@@ -282,6 +285,8 @@ class SACConfig(AlgorithmConfig):
             self.num_steps_sampled_before_learning_starts = (
                 num_steps_sampled_before_learning_starts
             )
+        if model_alpha is not NotProvided:
+            self.model_alpha = model_alpha
 
         return self
diff --git a/rllib/algorithms/sac/sac_torch_model.py b/rllib/algorithms/sac/sac_torch_model.py
index 6e98a624b..629573a56 100644
--- a/rllib/algorithms/sac/sac_torch_model.py
+++ b/rllib/algorithms/sac/sac_torch_model.py
@@ -110,6 +110,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
             torch.from_numpy(np.array([np.log(initial_alpha)])).float()
         )
         self.register_parameter("log_alpha", log_alpha)
+        self.alpha = torch.exp(self.log_alpha).detach()
 
         # Auto-calculate the target entropy.
         if target_entropy is None or target_entropy == "auto":
diff --git a/rllib/algorithms/sac/sac_torch_policy.py b/rllib/algorithms/sac/sac_torch_policy.py
index 4bb56d282..1d241d4d9 100644
--- a/rllib/algorithms/sac/sac_torch_policy.py
+++ b/rllib/algorithms/sac/sac_torch_policy.py
@@ -205,7 +205,10 @@ def actor_critic_loss(
         SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True), [], None
     )
 
-    alpha = torch.exp(model.log_alpha)
+    if policy.config["model_alpha"]:
+        alpha = model.alpha.to(model.log_alpha.device)
+    else:
+        alpha = torch.exp(model.log_alpha)
 
     # Discrete case.
     if model.discrete:
@@ -351,6 +354,7 @@ def actor_critic_loss(
     model.tower_stats["actor_loss"] = actor_loss
     model.tower_stats["critic_loss"] = critic_loss
     model.tower_stats["alpha_loss"] = alpha_loss
+    model.tower_stats["alpha_value"] = alpha
 
     # TD-error tensor in final stats
     # will be concatenated and retrieved for each individual batch item.
@@ -378,7 +382,8 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
             torch.stack(tree.flatten(policy.get_tower_stats("critic_loss")))
         ),
         "alpha_loss": torch.mean(torch.stack(policy.get_tower_stats("alpha_loss"))),
-        "alpha_value": torch.exp(policy.model.log_alpha),
+#        "alpha_value": torch.exp(policy.model.log_alpha),
+        "alpha_value": torch.mean(torch.stack(policy.get_tower_stats("alpha_value"))),
         "log_alpha_value": policy.model.log_alpha,
         "target_entropy": policy.model.target_entropy,
         "policy_t": torch.mean(torch.stack(policy.get_tower_stats("policy_t"))),

Wow, great catch @mannyv!

Have you opened a PR for your fix, or would you like someone on the RLlib team to upstream it instead?

Thanks,
@avnishn

@avnishn,
It is not really a fix yet, so I will open an issue. I think I may also see one or two other problems; I am still investigating whether they are real.
If I find more than one, would the team prefer separate issues or one issue with all of them listed together?

Could you share a link to the issue? I'd like to follow it. Thanks!