[rllib] Retrieve and modify the computed discrete action logits of the PPO agent

I’d like to retrieve the logits that the PPO agent’s policy computes for each component of a MultiDiscrete action, and then modify them so that the action sampling uses my modified logits.

I’m really sorry for the question, but the code is so abstracted that it’s hard for me to tell how I would retrieve the logits and then reset them. Is it in the forward pass of ModelV2? Any direction you could provide would be helpful, thank you.

Hey @lucas_spangher, thanks for the question. You are right, it’s not super intuitive how to get there. Action logits are processed inside the Policy.exploration object. PPO uses the StochasticSampling exploration module, which lives in ray/rllib/utils/exploration/stochastic_sampling.py.

If you want to modify the logits before the sampling step, sub-class StochasticSampling and override its get_exploration_action method: apply your logit modification there, make sure the sampling is done on those modified logits, and then return the actions (just like the parent StochasticSampling does).
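
A minimal, untested sketch of such a subclass (assuming the ModelV2/Policy-based API stack; the class name MyExploration, the module path my_dir/my_exploration.py, and the +1.0 logit shift are placeholders for your own logic):

from ray.rllib.utils.exploration.stochastic_sampling import StochasticSampling


class MyExploration(StochasticSampling):
    def get_exploration_action(self, *, action_distribution, timestep, explore=True):
        # action_distribution.inputs holds the raw model outputs, i.e. the
        # logits for a Categorical distribution.
        modified_logits = action_distribution.inputs + 1.0  # placeholder transform

        # Rebuild the distribution from the modified logits so that the
        # sampling below actually uses them (the wrapped distribution object
        # was constructed from the original, unmodified logits).
        # NOTE: for a MultiDiscrete action space the distribution is a
        # MultiCategorical, whose constructor also expects input_lens
        # (e.g. list(self.action_space.nvec)); pass those along as well.
        action_distribution = action_distribution.__class__(
            modified_logits, self.model
        )

        # Let the parent class do the usual stochastic sampling and return
        # (actions, logp), just like StochasticSampling itself does.
        return super().get_exploration_action(
            action_distribution=action_distribution,
            timestep=timestep,
            explore=explore,
        )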

Then in your config, just do:

config:
    exploration_config:
        type: [full path to your new class, e.g. "my_dir.my_exploration.MyExploration"]
        [other constructor args for your class]
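
In a Python config dict, the equivalent would look roughly like this (using the hypothetical module path from the example above):

config = {
    # ... the rest of your PPO config ...
    "exploration_config": {
        # Full path to your Exploration subclass (hypothetical path from above).
        "type": "my_dir.my_exploration.MyExploration",
        # Any extra keys here are passed to your class's constructor.
    },
}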

This would tell PPO to use your own exploration class instead.

Simplifying our “hooks” into the RLlib code is on our list of things to fix in the short term. This includes providing better (and more flexible) callback options to users via an event-registry-based system, as well as using that system to eventually replace, e.g., the exploration API and other currently hard-coded hooks.

Thanks for the response! I’m going to give it a try. Here’s what I’m trying to do (implement an ordinal response), so if there’s interest in incorporating that, I’d be happy to submit a PR.

Just to confirm, the action_dist is the logits, correct?

Hi @lucas_spangher,

The action distribution depends on your action space. It will be one of the classes in rllib/models/tf/tf_action_dist.py in the ray-project/ray GitHub repo, or the torch equivalent (torch_action_dist.py).

If it were Categorical, for example, that distribution has an inputs member variable that holds the logits. You could modify those and reassign them, e.g. action_dist.inputs = action_dist.inputs + 0. I think that would be all you need in your subclass of StochasticSampling (in its overridden get_exploration_action, since that is where the distribution gets passed in).

Your other option is to implement your own custom model and then modify the logits as your last operation before returning them from the forward method. As I think about it more that might be a more straightforward approach if you want to support multiple different kinds of action distributions.
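
For illustration, a minimal, untested sketch of that second approach in PyTorch (the model name, the reuse of RLlib’s default FullyConnectedNetwork, and the +1.0 shift are all placeholder choices, assuming "framework": "torch" in your config):

import torch.nn as nn

from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


class LogitModifyingModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # Reuse RLlib's default fully connected net to produce the raw logits.
        self.base = FullyConnectedNetwork(obs_space, action_space, num_outputs,
                                          model_config, name + "_base")

    def forward(self, input_dict, state, seq_lens):
        logits, state = self.base(input_dict, state, seq_lens)
        # Modify the logits as the very last step before returning them;
        # "+ 1.0" is a placeholder for your actual transformation.
        return logits + 1.0, state

    def value_function(self):
        return self.base.value_function()


# Register the model and point PPO at it via the model config:
ModelCatalog.register_custom_model("logit_modifying_model", LogitModifyingModel)
# config["model"] = {"custom_model": "logit_modifying_model"}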

Hey mannyv, thank you for the answer. This was incredibly clear and straightforward. I should have no trouble implementing what I want. I’ll keep you posted once I get it working.
