Fetch action probability distribution from trained policy

How severely does this issue affect your experience of using Ray?

  • High: It blocks me from completing my task.

How can I get the full action probability distribution for a particular state from a trained policy using Ray/RLlib 2.x?

I tried using policy.compute_single_action(state, full_fetch=True), but that only returns additional information about the selected action.


Hi @steff,

Others may have a better way, but the best way I know of is to take the action_dist_inputs entry from the extra fetches dictionary, use it to build your own action distribution, and compute the probabilities from that.
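A minimal sketch of that approach for a discrete action space. It assumes the default Categorical distribution, where action_dist_inputs holds unnormalized logits, so the probabilities are just a softmax over them. The helper name probs_from_dist_inputs is hypothetical, and the commented-out RLlib call (an Algorithm's compute_single_action with full_fetch=True returning an (action, state_out, extra_fetches) tuple) is an assumption about the 2.x API, not something run here.

```python
import numpy as np


def probs_from_dist_inputs(dist_inputs) -> np.ndarray:
    """Hypothetical helper: softmax over the last axis of categorical logits.

    Works for a single state (shape [num_actions]) or a batch
    (shape [batch, num_actions]).
    """
    logits = np.asarray(dist_inputs, dtype=np.float64)
    # Subtract the max logit for numerical stability; the result is unchanged.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)


# Assumed usage with a trained RLlib 2.x Algorithm `algo` (not executed here):
#   action, state_out, extra = algo.compute_single_action(obs, full_fetch=True)
#   probs = probs_from_dist_inputs(extra["action_dist_inputs"])

# Stand-in logits for a Discrete(3) action space:
probs = probs_from_dist_inputs([1.0, 0.0, -1.0])
```

The resulting vector has one entry per action and sums to 1, which is the "predict_proba"-style output asked for above.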

@arturn This is the third request for this I have seen in the forums this month. Perhaps it makes sense to store the probabilities of all actions in addition to the selected actions.

Thanks! I’ll open a feature request!

@mannyv Yes, that’s how we do it ourselves: build a fresh action distribution from the provided inputs!

@steff In the meantime, please have a look at our Policy classes. For example, in the SAC policies we turn action distribution inputs into distributions that you can sample from.
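To make "turning action distribution inputs into a distribution you can sample from" concrete without depending on RLlib internals, here is a hypothetical pure-Python stand-in for a categorical distribution built from logits. Its method names (sample, deterministic_sample, logp) mirror RLlib's ActionDistribution interface, but the class itself is only an illustration; inside RLlib you would instantiate the distribution class the policy already uses.

```python
import math
import random
from typing import List, Optional


class CategoricalFromInputs:
    """Hypothetical categorical distribution built from unnormalized logits."""

    def __init__(self, inputs: List[float]):
        # Softmax with max-subtraction for numerical stability.
        m = max(inputs)
        exps = [math.exp(x - m) for x in inputs]
        total = sum(exps)
        self.probs = [e / total for e in exps]

    def sample(self, rng: Optional[random.Random] = None) -> int:
        # Draw an action index with probability equal to its softmax weight.
        if rng is None:
            rng = random.Random()
        return rng.choices(range(len(self.probs)), weights=self.probs, k=1)[0]

    def deterministic_sample(self) -> int:
        # Greedy action: the index with the highest probability.
        return max(range(len(self.probs)), key=self.probs.__getitem__)

    def logp(self, action: int) -> float:
        # Log-probability of a given discrete action.
        return math.log(self.probs[action])
```

With logits [2.0, 0.5, -1.0], deterministic_sample() returns action 0, the most likely one, while sample() draws actions in proportion to dist.probs.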

@arturn I’m another one watching that feature request; it’s a blocker for me. I’ve looked at the example you pointed to, but unfortunately that code is way beyond my understanding.

@ihopethiswillfi Do you know what action distribution the algorithm you are looking at is using?
The SampleBatches passed around in RLlib can contain a SampleBatch.ACTION_DIST_INPUTS field.
Together with the distribution class, that field contains everything needed to reconstruct the action distribution object.

For example, a DiagGaussian distribution takes two vectors as inputs (means and log standard deviations), and both can be found in that field of a SampleBatch. So if you are running PPO training, the batches passed around in your PPO.training_step() method will contain these inputs.
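For continuous actions there is no finite probability table; what you can evaluate is a density. A minimal sketch, assuming the DiagGaussian inputs are stored as one concatenated vector [means..., log_stds...] that is split in half (our understanding of how RLlib's DiagGaussian classes read them). The function name is hypothetical.

```python
import math
from typing import Sequence


def diag_gaussian_log_density(dist_inputs: Sequence[float],
                              action: Sequence[float]) -> float:
    """Hypothetical helper: log-density of `action` under a diagonal Gaussian.

    Assumes dist_inputs = [mean_1..mean_n, log_std_1..log_std_n].
    """
    n = len(dist_inputs) // 2
    if len(dist_inputs) != 2 * n or len(action) != n:
        raise ValueError("expected 2*n inputs for an n-dimensional action")
    means, log_stds = dist_inputs[:n], dist_inputs[n:]
    total = 0.0
    for a, mu, log_std in zip(action, means, log_stds):
        std = math.exp(log_std)
        # log N(a; mu, std^2) for one independent dimension.
        total += -0.5 * ((a - mu) / std) ** 2 - log_std - 0.5 * math.log(2 * math.pi)
    return total
```

For a one-dimensional action with mean 0 and log std 0 (a standard normal), evaluating at 0 gives -0.5·log(2π), about -0.919; independent dimensions simply add their log-densities.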

We are currently working on replacing the concept of Policy and will likely include the action distribution itself in such batches in the future.

@arturn Perfect. Yeah, I really miss the predict_proba() that I use pretty much every time I train a traditional ML model 🙂

I’m using PPO, and my action space is discrete with 3 possible actions. However, I think I need to dive deeper into the inner workings of RLlib, and RL in general, before I can solve this. Thanks a lot for your help.
