Policy.compute_log_likelihoods should allow computing the log-likelihoods with or without applying the exploration (e.g. SoftQ exploration).
This would make it possible to compute the true log-likelihoods of the actions that were actually taken (i.e. including the impact of the exploration).
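To make the difference concrete, here is a minimal, self-contained sketch in plain PyTorch (not RLlib internals; the Q-values and the temperature value are made up): with SoftQ exploration the actions are sampled from a temperature-scaled softmax over the Q-values, so the log-likelihood of the action actually taken is not the one reported by the raw (unscaled) distribution.

import torch

q_values = torch.tensor([[2.0, 1.0, 0.5]])  # Q(s, a) for one observation (made up)
temperature = 10.0                          # SoftQ temperature (made up)

# Distribution the policy reports when exploration is ignored (raw logits).
raw_dist = torch.distributions.Categorical(logits=q_values)
# Distribution the action was actually sampled from under SoftQ exploration.
explore_dist = torch.distributions.Categorical(logits=q_values / temperature)

action = torch.tensor([2])
print(raw_dist.log_prob(action))      # log-likelihood ignoring the exploration
print(explore_dist.log_prob(action))  # "true" log-likelihood of the action taken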
Something like the following (the proposed change is at the end, after # ADDITION BELOW):
# Assumed imports (in older Ray versions, convert_to_torch_tensor lives in
# ray.rllib.utils.torch_ops instead of ray.rllib.utils.torch_utils):
from typing import List, Optional, Union

import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.torch_utils import convert_to_torch_tensor
from ray.rllib.utils.typing import TensorType


def compute_log_likelihoods(
    policy,
    actions: Union[List[TensorType], TensorType],
    obs_batch: Union[List[TensorType], TensorType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
    prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
    # NEW: if True, apply the configured exploration (e.g. SoftQ) before
    # computing the log-likelihoods.
    apply_exploration: bool = False,
) -> TensorType:
    if policy.action_sampler_fn and policy.action_distribution_fn is None:
        raise ValueError(
            "Cannot compute log-prob/likelihood w/o an "
            "`action_distribution_fn` and a provided "
            "`action_sampler_fn`!"
        )

    with torch.no_grad():
        input_dict = policy._lazy_tensor_dict(
            {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
        )
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
        state_batches = [
            convert_to_torch_tensor(s, policy.device)
            for s in (state_batches or [])
        ]

        # Exploration hook before each forward pass.
        policy.exploration.before_compute_actions(explore=False)

        # Action dist class and inputs are generated via custom function.
        if policy.action_distribution_fn:
            # Try the new action_distribution_fn signature, supporting
            # state_batches and seq_lens.
            try:
                (
                    dist_inputs,
                    dist_class,
                    state_out,
                ) = policy.action_distribution_fn(
                    policy,
                    policy.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
            # Fall back to the old signature (to stay backward compatible).
            # TODO: Remove in future.
            except TypeError as e:
                if (
                    "positional argument" in e.args[0]
                    or "unexpected keyword argument" in e.args[0]
                ):
                    dist_inputs, dist_class, _ = policy.action_distribution_fn(
                        policy=policy,
                        model=policy.model,
                        obs_batch=input_dict[SampleBatch.CUR_OBS],
                        explore=False,
                        is_training=False,
                    )
                else:
                    raise e
        # Default action-dist inputs calculation.
        else:
            dist_class = policy.dist_class
            dist_inputs, _ = policy.model(input_dict, state_batches, seq_lens)

        action_dist = dist_class(dist_inputs, policy.model)

        # ADDITION BELOW
        if apply_exploration and policy.config["explore"]:
            # get_exploration_action is called only for its side effect: due to
            # a "bug" in TorchCategorical, it modifies dist_inputs in place
            # through action_dist, ...
            _ = policy.exploration.get_exploration_action(
                action_distribution=action_dist,
                timestep=policy.global_timestep,
                explore=policy.config["explore"],
            )
            # ... so re-building the distribution picks up the exploration's
            # changes (e.g. the SoftQ temperature).
            action_dist = dist_class(dist_inputs, policy.model)

        log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])

        return log_likelihoods
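For reference, a hypothetical call site, assuming policy is a built torch policy and batch is a previously collected SampleBatch (both names are placeholders for this sketch):

logp_raw = compute_log_likelihoods(
    policy,
    actions=batch[SampleBatch.ACTIONS],
    obs_batch=batch[SampleBatch.CUR_OBS],
    apply_exploration=False,  # current behavior: raw action distribution
)
logp_true = compute_log_likelihoods(
    policy,
    actions=batch[SampleBatch.ACTIONS],
    obs_batch=batch[SampleBatch.CUR_OBS],
    apply_exploration=True,   # proposed: include the exploration's effect
)

The second call would give the "true" log-likelihoods of the actions that were actually sampled during the rollout.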