How severely does this issue affect your experience of using Ray?
- None: Just asking a question out of curiosity
Hi all,
Can I access the recent action probs after each episode/env step?
I’ve thought of using a custom on_episode_step callback, where I have access to the worker, policies and episode objects, but I don’t know which of these objects contains the recent action probs.
Hi Klaus!
Would you like to ask your question in RLlib Office Hours? Just add your question to this doc: RLlib Office Hours - Google Docs
Thanks! Hope to see you there!
Done
I didn’t know about this opportunity
Hi Klaus, sorry we did not get to your question during the last office hours. I moved your question to the OH on July 5th with Sven. Hope that is OK?
Thanks!
Christy
Hey @klausk55,
can you try calling episode.last_extra_action_outs_for([optional agent ID]) inside your on_episode_step() callback?
This should give you the extra-action outputs, including the action probs, logp, logits, etc… for the most recent action computation step.
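from typing import Dict

import numpy as np

# Import paths below match Ray 2.x; older releases expose DefaultCallbacks
# under ray.rllib.agents.callbacks instead.
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import Episode, RolloutWorker
from ray.rllib.policy import Policy


class MyCallbacks(DefaultCallbacks):
    def on_episode_step(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Extra action outputs (a dict) of the most recent action computation.
        extra_outs = episode.last_extra_action_outs_for()
        # Keep only the logp; skip steps where the policy did not report it.
        if "action_logp" in extra_outs:
            episode.user_data.setdefault("logps", []).append(extra_outs["action_logp"])

    def on_episode_end(
        self,
        *,
        worker: RolloutWorker,
        base_env: BaseEnv,
        policies: Dict[str, Policy],
        episode: Episode,
        env_index: int,
        **kwargs
    ):
        # Make the mean logp available in custom metrics for each episode.
        logps = episode.user_data.get("logps")
        if logps:
            episode.custom_metrics["logps"] = float(np.mean(logps))


# In your config, make sure to set up your custom callback class from above:
config["callbacks"] = MyCallbacks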
Hey @sven1977,
Yes this gives me the most recent action prob, logp, logits (action dist inputs) and vf prediction
Could I also use the logits (action dist inputs) to get probs of all actions and not only of the sampled action (or at least logps)? My policy has a categorical action dist.
Yes, you can use policy.dist_class to get your action distribution class and then use action_dist_inputs to get all the logps.
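For a categorical action distribution, the same numbers can also be computed by hand from the logits. The following is only a minimal sketch, not RLlib code: it assumes action_dist_inputs holds the unnormalized logits of a single step, and the function names and example logits are made up for illustration.

```python
import math
from typing import List, Sequence


def categorical_logps(logits: Sequence[float]) -> List[float]:
    """Return log-probabilities of all actions (log-softmax of the logits)."""
    # Subtract the max logit before exponentiating for numerical stability.
    max_logit = max(logits)
    log_norm = max_logit + math.log(sum(math.exp(x - max_logit) for x in logits))
    return [x - log_norm for x in logits]


def categorical_probs(logits: Sequence[float]) -> List[float]:
    """Return probabilities of all actions (softmax of the logits)."""
    return [math.exp(logp) for logp in categorical_logps(logits)]


# Hypothetical logits for an env with three discrete actions, e.g. the
# "action_dist_inputs" entry collected in on_episode_step().
example_logits = [2.0, 1.0, 0.1]
all_logps = categorical_logps(example_logits)
all_probs = categorical_probs(example_logits)
```

Inside a callback, the logits would come from the extra action outputs of the current step instead of a hard-coded list; going through policy.dist_class as suggested above is the more general route, since it also covers non-categorical distributions.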