Policy.compute_log_likelihoods should allow computing log-likelihoods with or without applying exploration (e.g. SoftQ exploration)

This would allow computing the true log-likelihoods of the actions actually taken (i.e. including the impact of exploration).

Something like the following (the new code is at the end of the function, after the # ADDITION BELOW comment):

# Imports assumed for this snippet (exact module paths differ between RLlib
# versions, e.g. convert_to_torch_tensor may live in torch_ops or torch_utils):
from typing import List, Optional, Union

import torch

from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.utils.torch_ops import convert_to_torch_tensor
from ray.rllib.utils.typing import TensorType


def compute_log_likelihoods(
    policy,
    actions: Union[List[TensorType], TensorType],
    obs_batch: Union[List[TensorType], TensorType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
    prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
    apply_exploration: bool = False,
) -> TensorType:

    if policy.action_sampler_fn and policy.action_distribution_fn is None:
        raise ValueError(
            "Cannot compute log-prob/likelihood w/o an "
            "`action_distribution_fn` and a provided "
            "`action_sampler_fn`!"
        )

    with torch.no_grad():
        input_dict = policy._lazy_tensor_dict(
            {SampleBatch.CUR_OBS: obs_batch, SampleBatch.ACTIONS: actions}
        )
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
        state_batches = [
            convert_to_torch_tensor(s, policy.device)
            for s in (state_batches or [])
        ]

        # Exploration hook before each forward pass.
        policy.exploration.before_compute_actions(explore=False)

        # Action dist class and inputs are generated via custom function.
        if policy.action_distribution_fn:

            # Try new action_distribution_fn signature, supporting
            # state_batches and seq_lens.
            try:
                (
                    dist_inputs,
                    dist_class,
                    state_out,
                ) = policy.action_distribution_fn(
                    policy,
                    policy.model,
                    input_dict=input_dict,
                    state_batches=state_batches,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False,
                )
            # Trying the old way (to stay backward compatible).
            # TODO: Remove in future.
            except TypeError as e:
                if (
                    "positional argument" in e.args[0]
                    or "unexpected keyword argument" in e.args[0]
                ):
                    dist_inputs, dist_class, _ = policy.action_distribution_fn(
                        policy=policy,
                        model=policy.model,
                        obs_batch=input_dict[SampleBatch.CUR_OBS],
                        explore=False,
                        is_training=False,
                    )
                else:
                    raise e

        # Default action-dist inputs calculation.
        else:
            dist_class = policy.dist_class
            dist_inputs, _ = policy.model(input_dict, state_batches, seq_lens)

        action_dist = dist_class(dist_inputs, policy.model)
        # ADDITION BELOW
        if apply_exploration and policy.config["explore"]:
            # Using that because of a "bug" in TorchCategorical
            #  which modify dist_inputs through action_dist:
            _ = policy.exploration.get_exploration_action(
                action_distribution=action_dist,
                timestep=policy.global_timestep,
                explore=policy.config["explore"],
            )
            action_dist = dist_class(dist_inputs, policy.model)

        log_likelihoods = action_dist.logp(input_dict[SampleBatch.ACTIONS])

        return log_likelihoods
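
For context, a minimal usage sketch (assumptions: a torch DQN policy with SoftQ exploration built via DQNTrainer; the import path, config keys and the CartPole setup are illustrative and may differ across RLlib versions):

import numpy as np
from ray.rllib.agents.dqn import DQNTrainer  # location may differ in newer RLlib

trainer = DQNTrainer(
    env="CartPole-v0",
    config={
        "framework": "torch",
        "exploration_config": {"type": "SoftQ", "temperature": 10.0},
    },
)
policy = trainer.get_policy()

obs_batch = np.zeros((2, 4), dtype=np.float32)  # two dummy CartPole observations
actions = np.array([0, 1])                      # actions "taken" in those states

# Log-likelihoods under the raw action distribution (current behavior).
logp_raw = compute_log_likelihoods(policy, actions, obs_batch)

# Log-likelihoods under the exploration-adjusted (SoftQ, temperature-scaled)
# distribution, i.e. the distribution the actions were actually sampled from.
logp_explored = compute_log_likelihoods(
    policy, actions, obs_batch, apply_exploration=True
)

With a SoftQ temperature != 1.0, logp_raw and logp_explored should differ; that difference is exactly the information the current API cannot expose.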

Great idea! We should also probably add the typical explore: bool = False arg to this method.
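
For illustration only, a sketch of that variant (naming per the comment above; whether explore replaces or complements apply_exploration, and its default, are open design choices, not existing API):

def compute_log_likelihoods(
    policy,
    actions: Union[List[TensorType], TensorType],
    obs_batch: Union[List[TensorType], TensorType],
    state_batches: Optional[List[TensorType]] = None,
    prev_action_batch: Optional[Union[List[TensorType], TensorType]] = None,
    prev_reward_batch: Optional[Union[List[TensorType], TensorType]] = None,
    explore: bool = False,
) -> TensorType:
    # Same body as above, with `apply_exploration and policy.config["explore"]`
    # replaced by just `explore`.
    ...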