Hi @rajfly & @avnishn,
You guys are going to like this one. I was able to reproduce the same curves as @rajfly with his reproduction script (nice job providing one). Here is where it gets interesting: I was also able to reproduce it, or at least something similar, with TensorFlow, depending on whether I use `framework: tf` (pink) or `framework: tf2` (green). These results are from master as of Friday at 18:00 EST.
Take a look:
What is the difference? Well, it appears that `tf` and `tf2` treat `model.alpha` differently. In `tf` it appears to change as a function of `log_alpha`, whereas with `tf2` it is a static tensor that never changes.
Take a look:
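(For anyone following along without the screenshots, here is a rough sketch of how I read the behavior. The model line is paraphrased from memory, so treat it as an assumption rather than a quote of the RLlib TF model.)

```python
import math
import tensorflow as tf

initial_alpha = 1.0

# Paraphrase (assumption): alpha is derived from log_alpha once, at construction time.
log_alpha = tf.Variable(math.log(initial_alpha), dtype=tf.float32, name="log_alpha")
alpha = tf.exp(log_alpha)

# framework: tf  (graph mode): tf.exp() builds a graph op, so every session run
# re-evaluates it and alpha tracks whatever the entropy optimizer did to log_alpha.
# framework: tf2 (eager mode): tf.exp() runs immediately, so alpha is a constant
# snapshot of exp(log_alpha) and never moves again, even as log_alpha is optimized.
log_alpha.assign(math.log(0.5))
print(float(alpha))              # eager: still 1.0 -- the "static tensor" behavior
print(float(tf.exp(log_alpha)))  # 0.5 -- what re-evaluation would give (tf-like)
```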
OK, now that we have a theory, let's go see what torch is doing.
The torch policy explicitly calculates alpha in the loss function, so it should behave like `tf`.
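(Minimal sanity check of that reasoning, not RLlib code: recomputing the exponential inside the loss means alpha always reflects the current, optimized `log_alpha`.)

```python
import torch

# Illustration only: exp(log_alpha) recomputed per loss call tracks the
# optimizer's updates to log_alpha, just like framework: tf above.
log_alpha = torch.nn.Parameter(torch.zeros(1))
print(float(torch.exp(log_alpha)))  # 1.0

with torch.no_grad():
    log_alpha.copy_(torch.log(torch.tensor([0.5])))
print(float(torch.exp(log_alpha)))  # 0.5 -- alpha follows log_alpha
```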
Alright, that matches what we predicted. I think we have a second prediction: if we can make alpha in torch behave like `tf2`, it should also solve CartPole. Let's try.
Yup that did it.
So now I am wondering: are there any other algorithms in RLlib affected by this issue? I am not sure, but I will leave that investigation to the paid professionals. =)
…
diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py
index b26bc26e0..483b9ede3 100644
--- a/rllib/algorithms/sac/sac.py
+++ b/rllib/algorithms/sac/sac.py
@@ -105,6 +105,8 @@ class SACConfig(AlgorithmConfig):
self.use_state_preprocessor = DEPRECATED_VALUE
self.worker_side_prioritization = DEPRECATED_VALUE
+ self.model_alpha = True
+
@override(AlgorithmConfig)
def training(
self,
@@ -126,6 +128,7 @@ class SACConfig(AlgorithmConfig):
_deterministic_loss: Optional[bool] = NotProvided,
_use_beta_distribution: Optional[bool] = NotProvided,
num_steps_sampled_before_learning_starts: Optional[int] = NotProvided,
+ model_alpha: Optional[bool] = NotProvided,
**kwargs,
) -> "SACConfig":
"""Sets the training related configuration.
@@ -282,6 +285,8 @@ class SACConfig(AlgorithmConfig):
self.num_steps_sampled_before_learning_starts = (
num_steps_sampled_before_learning_starts
)
+ if model_alpha is not NotProvided:
+ self.model_alpha = model_alpha
return self
diff --git a/rllib/algorithms/sac/sac_torch_model.py b/rllib/algorithms/sac/sac_torch_model.py
index 6e98a624b..629573a56 100644
--- a/rllib/algorithms/sac/sac_torch_model.py
+++ b/rllib/algorithms/sac/sac_torch_model.py
@@ -110,6 +110,7 @@ class SACTorchModel(TorchModelV2, nn.Module):
torch.from_numpy(np.array([np.log(initial_alpha)])).float()
)
self.register_parameter("log_alpha", log_alpha)
+ self.alpha = torch.exp(self.log_alpha).detach()
# Auto-calculate the target entropy.
if target_entropy is None or target_entropy == "auto":
diff --git a/rllib/algorithms/sac/sac_torch_policy.py b/rllib/algorithms/sac/sac_torch_policy.py
index 4bb56d282..1d241d4d9 100644
--- a/rllib/algorithms/sac/sac_torch_policy.py
+++ b/rllib/algorithms/sac/sac_torch_policy.py
@@ -205,7 +205,10 @@ def actor_critic_loss(
SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True), [], None
)
- alpha = torch.exp(model.log_alpha)
+ if policy.config["model_alpha"]:
+ alpha = model.alpha.to(model.log_alpha.device)
+ else:
+ alpha = torch.exp(model.log_alpha)
# Discrete case.
if model.discrete:
@@ -351,6 +354,7 @@ def actor_critic_loss(
model.tower_stats["actor_loss"] = actor_loss
model.tower_stats["critic_loss"] = critic_loss
model.tower_stats["alpha_loss"] = alpha_loss
+ model.tower_stats["alpha_value"] = alpha
# TD-error tensor in final stats
# will be concatenated and retrieved for each individual batch item.
@@ -378,7 +382,8 @@ def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
torch.stack(tree.flatten(policy.get_tower_stats("critic_loss")))
),
"alpha_loss": torch.mean(torch.stack(policy.get_tower_stats("alpha_loss"))),
- "alpha_value": torch.exp(policy.model.log_alpha),
+# "alpha_value": torch.exp(policy.model.log_alpha),
+ "alpha_value": torch.mean(torch.stack(policy.get_tower_stats("alpha_value"))),
"log_alpha_value": policy.model.log_alpha,
"target_entropy": policy.model.target_entropy,
"policy_t": torch.mean(torch.stack(policy.get_tower_stats("policy_t"))),