How do you get action probabilities from a policy?

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

Given a state and a trained policy, how can I compute the action distribution for that state under the policy? Looking at the docs, it sounds to me like that’s what the compute_log_likelihoods function is for, but it doesn’t behave as expected and when I asked about that function specifically I got no answers.

I’ve also looked at the example Querying a policy’s action distribution, but my results using this method aren’t making sense either. In the terminology of that example, I would expect

np.sum([np.e ** dist.logp(a) for a in actions])

to equal the size of the state space |S|, but instead I’m getting a much smaller number, so these logps must not mean what I think, i.e. they are not log(P(action|state)). The example is outdated (from_batch is deprecated, for instance), so I had to make some changes; maybe I’m getting the wrong distribution somehow?
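For reference, if the logps really were log(P(action|state)), then for any single state the exponentiated logps should sum to 1 over the actions, and over all |S| states the grand total should be |S|. A minimal numpy check of that arithmetic (the logits below are made-up stand-ins, not values from my script):

```python
import numpy as np

# Made-up logits for |S| = 3 states and 2 actions, standing in for the
# policy model's output (not numbers from the actual repro).
logits = np.array([[0.5, -0.2],
                   [1.0, 1.0],
                   [-0.3, 0.7]])

# Softmax turns each state's logits into action probabilities.
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

# For a fixed state, exp(logp(a)) summed over the actions is 1 ...
assert np.allclose(probs.sum(axis=1), 1.0)

# ... so summed over every (state, action) pair, the total is |S|.
print(probs.sum())  # ≈ 3.0 = |S|
```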

Why is this so hard to find clear/consistent documentation on? Computing an action distribution is one of the simplest things you could possibly want to do with an RL agent. I’d be happy to submit a PR with better docs on this if someone can explain what the intended solution is.

Hi @jwarley ,

Sorry that this has taken a while. I have responded on your GitHub issue:

The log_likelihoods that you compute are representations of the Q-values, because you are using DQN.
Since the Q-values of different state-action pairs in your environment don’t vary much, the likelihoods are all very close. DQN uses an epsilon-greedy action sampling function and therefore almost always simply picks the action with the slightly higher Q-value. The likelihoods that you print in your repro script underline this: right is always slightly higher than left.



@arturn I’ve actually just had this exact same problem come up, but I’m still unsure. Is there no easy built-in way of getting action probabilities for a discrete action space? Basically, an equivalent of algorithm.compute_single_action(), but the output is an entire distribution? E.g. {0: 0.4, 1:0.6} if there’s two actions, and action 0 has probability 0.4 in the current state.

If there isn’t a built-in way, do I understand correctly that I’d have to (1) get the preprocessor and preprocess the observation from my env, (2) get the logits from policy.model (3) create an action_dist (4) get the logp for a given action from that, and (5) and then take the exponent of that - should that directly give me the action probability?
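Those five steps can be sketched in plain numpy, with the preprocessor and the policy model replaced by made-up stand-ins (the identity function and a fixed linear layer; none of this is RLlib code):

```python
import numpy as np

def softmax(x):
    z = x - x.max()
    e = np.exp(z)
    return e / e.sum()

# (1) "Preprocess" the raw observation (identity here; RLlib's preprocessor
#     would flatten / one-hot depending on the observation space).
obs = np.array([0.1, -0.4, 0.3])

# (2) Get logits from the policy model -- stood in for by a fixed linear
#     layer with made-up weights.
W = np.array([[0.2, -0.1],
              [0.5, 0.4],
              [-0.3, 0.8]])
logits = obs @ W

# (3) Build the action distribution: a categorical over the logits.
probs = softmax(logits)
assert np.isclose(probs.sum(), 1.0)

# (4) Get the logp for a given action, and (5) exponentiate it: that
#     recovers exactly the action's probability under the distribution.
action = 1
logp = np.log(probs[action])
assert np.isclose(np.exp(logp), probs[action])

print({a: float(p) for a, p in enumerate(probs)})
```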

And per jwarley’s original question and your answer to it, does this mean this won’t work for DQN policies? Also not SimpleQ then? Are there any other policies for which this won’t work? Is there any way of doing this for arbitrary policy classes?

Thank you!

Hi @mgerstgrasser,

If you call algorithm.compute_single_action(…, full_fetch=True)

You should get a 3-tuple back: (action, state, extra_fetches)

extra_fetches will be a dictionary, and you should be able to get the probability of the sampled action with extra_fetches[SampleBatch.ACTION_PROB]


@mgerstgrasser , would you mind opening an issue with a repro script and a short description of your hardware? We can then maybe take Manny’s approach to the next level and compute a slope for the queue to warn if it stays positive.


I think you replied on the wrong thread.

Thanks @mannyv , you are right.

Oh, thank you! That only returns the action prob for the sampled action, not all of them, but in the specific setting where I need this there are only two actions! :smiley: Thank you!

If you wanted the probabilities for all the actions, you could create a new subclass of the policy you are using and override extra_action_out so that it adds an entry with the probabilities for all the actions. That method takes the action distribution as an input, so it should be pretty straightforward.

In this case you may also need to update the algorithm to pick the new policy. I am still trying to wrap my head around all the 2.0 api updates.
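A toy sketch of that pattern, with illustrative stand-in classes rather than RLlib's actual base classes (only the override-and-extend shape is the point here):

```python
import numpy as np

# Toy stand-ins for RLlib's pieces (names and signatures are illustrative,
# not RLlib's real API).
class Categorical:
    def __init__(self, logits):
        e = np.exp(logits - logits.max())
        self.probs = e / e.sum()

class BasePolicy:
    def extra_action_out(self, action_dist):
        return {}

# Subclass that adds the full probability vector to the extra outputs,
# mirroring what overriding a policy's extra-outputs hook would do.
class ProbOutPolicy(BasePolicy):
    def extra_action_out(self, action_dist):
        out = super().extra_action_out(action_dist)
        out["action_probs"] = action_dist.probs
        return out

dist = Categorical(np.array([0.4, 0.6]))
fetches = ProbOutPolicy().extra_action_out(dist)
print(fetches["action_probs"])  # full distribution, sums to 1
```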