Hi there,
I am going over the SAC and CQL implementations in RLlib and noticed a difference in the target Q evaluation. In SAC, the target Q function subtracts the entropy term, which I think is correct. However, in CQL, whose Bellman update is based on SAC, there is no entropy subtraction. Is there a particular reason for this?
Thanks.
mannyv
September 2, 2021, 2:02am
2
Hi @captainzhao ,
Check out this comment by @michaelzhiluo in an old issue.
opened 04:35AM - 28 Jun 21 UTC · closed 10:25AM - 29 Jun 21 UTC · labels: bug, P2, rllib
### What is the problem?
*Ray version and other system information (Python version, TensorFlow version, OS):*
In RLlib, CQL is built on top of SAC. They share the same `actor_loss` and `sac_critic_loss` (`critic_loss` for CQL is `sac_critic_loss + cql_loss`).
SAC `critic_loss` is obtained by
```python
sac_critic_loss = nn.functional.mse_loss(q_t_selected, q_t_target)
```
The problem is that the implementation of `q_t_target` differs between SAC and CQL. In CQL, it does not include the entropy term `alpha * log_pis_tp1`.
In `sac_torch_policy.py`:
```python
# Target q network evaluation.
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                         policy_tp1)
q_tp1 -= alpha * log_pis_tp1
q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * \
    q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = (
train_batch[SampleBatch.REWARDS] +
(policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked
).detach()
```
In `cql_torch_policy.py`:
```python
# Target q network evaluation.
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best
# compute RHS of bellman equation
q_t_target = (
rewards +
(discount**policy.config["n_step"]) * q_tp1_best_masked).detach()
```
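To see what the missing term changes numerically, here is a minimal sketch contrasting the two target computations above on made-up scalar values (the numbers for `q_tp1`, `log_pi_tp1`, `alpha`, etc. are hypothetical, not taken from RLlib):

```python
# Hypothetical scalar values for illustration only.
gamma, n_step, alpha = 0.99, 1, 0.2
reward, done = 1.0, 0.0
q_tp1, log_pi_tp1 = 5.0, -1.5  # target Q and log-prob at next state

# SAC-style soft target: the entropy bonus (-alpha * log_pi) is added
# to the target Q before bootstrapping.
soft_q_tp1 = q_tp1 - alpha * log_pi_tp1  # 5.0 - 0.2 * (-1.5) = 5.3
sac_target = reward + (gamma ** n_step) * (1.0 - done) * soft_q_tp1

# CQL-style target (as in cql_torch_policy.py): plain bootstrapped Q,
# with no entropy subtraction.
cql_target = reward + (gamma ** n_step) * (1.0 - done) * q_tp1

print(sac_target)  # 6.247
print(cql_target)  # 5.95
```

The gap between the two targets is exactly `gamma**n_step * (-alpha * log_pi_tp1)`, so whenever `alpha > 0` the two losses regress the critics toward different values.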
### Reproduction (REQUIRED)
Please provide a short code snippet (less than 50 lines if possible) that can be copy-pasted to reproduce the issue. The snippet should have **no external library dependencies** (i.e., use fake or mock data / environments):
If the code snippet cannot be run by itself, the issue will be closed with "needs-repro-script".
- [x] I have verified my script runs in a clean environment and reproduces the issue.
- [x] I have verified the issue also occurs with the [latest wheels](https://docs.ray.io/en/master/installation.html).