How are actions computed from action_dist_inputs?

Hey community,

You can get action_dist_inputs from algo.compute_single_action(obs_vector, policy_id="1", full_fetch=True), but it’s not clear to me what happens to these action_dist_inputs to compute the actual actions. Somehow, this is missing from the documentation, and I haven’t been able to find it in the code.

Any help welcome :slight_smile:

Hi,

I actually found a way using Catalogs.

You can retrieve your model’s catalog from your policy this way:


# Rebuild the catalog from the policy's model config
model_config = policy.model.config
catalog = model_config.catalog_class(
    model_config.observation_space,
    model_config.action_space,
    model_config.model_config_dict,
)

and then use it to compute actions from algo:

import torch

# With full_fetch=True, the result is (action, state_outs, extra_fetches)
outputs = algo.compute_single_action(obs_vector, policy_id="prey", full_fetch=True)
action_dist_inputs = outputs[2]["action_dist_inputs"]

# Rebuild the torch action distribution from the raw logits and sample from it
# (from_logits expects a tensor, so convert the fetched numpy array first)
action_dist_class = catalog.get_action_dist_cls(framework="torch")
action_dist = action_dist_class.from_logits(torch.from_numpy(action_dist_inputs))
actions = action_dist.sample().numpy()

So the last two lines are supposedly what happens to the action_dist_inputs under the hood.
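To make the idea concrete without a live RLlib policy, here is a minimal numpy sketch of what those last two lines amount to for a discrete action space with a Categorical distribution (the function name and shapes here are illustrative assumptions, not RLlib's API):

```python
import numpy as np

def sample_action(action_dist_inputs, rng=None):
    """Turn raw distribution inputs (logits) into a concrete action.

    For a discrete action space, the logits are softmaxed into
    per-action probabilities, then one action index is sampled.
    """
    rng = rng or np.random.default_rng(0)
    logits = np.asarray(action_dist_inputs, dtype=np.float64)
    # Numerically stable softmax: subtract the max before exponentiating.
    z = np.exp(logits - logits.max())
    probs = z / z.sum()
    # Sample one action index according to the categorical probabilities.
    action = rng.choice(len(probs), p=probs)
    return action, probs

action, probs = sample_action([2.0, 0.5, -1.0])
```

For continuous spaces the same pattern applies, but the logits parameterize e.g. a diagonal Gaussian (means and log-stds) instead of category probabilities.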

The from_logits function is described in this file: ray/rllib/models/torch/torch_distributions.py at master · ray-project/ray · GitHub