Access action probs after each episode/env step

How severely does this issue affect your experience of using Ray?

  • None: Just asking a question out of curiosity

Hi all,

Can I access the recent action probs after each episode/env step?
I’ve thought of using a custom on_episode_step callback, where I have access to the worker, policies, and episode objects, but I don’t know which of these contains the recent action probs.

Hi Klaus! :wave:t3:

Would you like to ask your question in RLlib Office Hours? :writing_hand:t3: Just add your question to this doc: RLlib Office Hours - Google Docs

Thanks! Hope to see you there!

Done :white_check_mark:
I didn’t know about this opportunity :+1:

Hi Klaus, sorry we did not get to your question during the last office hour. I moved it to the OH on July 5th with Sven. Hope that is OK?

Thanks!
Christy


Hey @klausk55 ,
can you try calling:

episode.last_extra_action_outs_for([optional agent ID])

inside your on_episode_step() callback?

This should give you the extra-action outputs, including the action probs, logps, logits, etc., for the most recent action computation step.

from typing import Dict

import numpy as np

# In newer Ray versions: from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.policy import Policy


class MyCallbacks(DefaultCallbacks):
    def on_episode_start(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Prepare a container for the per-step logps of this episode.
        episode.user_data["logps"] = []

    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Extra-action outputs of the most recent action computation
        # (action_prob, action_logp, action_dist_inputs, vf_preds, ...).
        extra_outs = episode.last_extra_action_outs_for()
        episode.user_data["logps"].append(extra_outs["action_logp"])

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Make the mean logp available as a custom metric for each episode.
        episode.custom_metrics["logps"] = float(np.mean(episode.user_data["logps"]))


# In your config, make sure to set up your custom callback class from above:
config["callbacks"] = MyCallbacks


Hey @sven1977,
Yes, this gives me the most recent action prob, logp, logits (action dist inputs), and vf prediction :+1:

Could I also use the logits (action dist inputs) to get the probs of all actions, not only of the sampled action (or at least their logps)? My policy has a categorical action dist.

Yes, you can use policy.dist_class to get the distribution class and then feed it the action_dist_inputs to get the probs (or logps) of all actions.
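
For example, something like this inside on_episode_step() should work (a sketch, assuming a single-agent setup with a Torch policy and a categorical action dist; "default_policy" is RLlib's default policy ID):

import torch

# Look up the policy that produced the action (single-agent assumption).
policy = policies["default_policy"]

# Stored logits for ALL actions from the most recent action computation.
extra_outs = episode.last_extra_action_outs_for()
logits = extra_outs["action_dist_inputs"]

# Rebuild the action distribution from the logits via the policy's dist class.
action_dist = policy.dist_class(torch.from_numpy(logits), policy.model)

# TorchCategorical wraps torch.distributions.Categorical in `.dist`, which
# exposes probs/log-probs for every action, not just the sampled one.
all_probs = action_dist.dist.probs
all_logps = torch.log(all_probs)

(Since the dist is categorical, a plain softmax over the logits would give the same probs.)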
