Hey community,
You can get action_dist_inputs
from algo.compute_single_action(obs_vector, policy_id="1", full_fetch=True)
, but it’s not clear to me what happens to these action_dist_inputs to compute the actual actions. Somehow, this is missing from the documentation and I haven’t been able to find it in the code.
Any help welcome
Hi,
I actually found a way using Catalogs.
You can retrieve your model's catalog from your policy this way:
model_config = policy.model.config
catalog = model_config.catalog_class(model_config.observation_space, model_config.action_space, model_config.model_config_dict)
and then use it to compute actions from algo:
outputs = algo.compute_single_action(obs_vector, policy_id="prey", full_fetch=True)
action_dist_inputs = outputs[2]["action_dist_inputs"]
action_dist_class = catalog.get_action_dist_cls(framework="torch")
action_dist = action_dist_class.from_logits(action_dist_inputs)
actions = action_dist.sample().numpy()
So here the last two lines are supposedly what happens to the action_dist_inputs afterwards.
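For a discrete action space, those last two steps conceptually reduce to "softmax the logits, then draw an action index". Here is a minimal pure-Python sketch of that idea (an illustration only, not RLlib's actual implementation; the real work is done by the distribution class returned by catalog.get_action_dist_cls, e.g. a torch categorical distribution):

```python
import math
import random

def softmax(logits):
    # Subtract the max for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_action(action_dist_inputs, rng=random):
    """Mimic what a categorical action distribution does with its logits:
    turn them into probabilities, then draw one action index — roughly
    what `action_dist_class.from_logits(...).sample()` does for a
    discrete action space."""
    probs = softmax(action_dist_inputs)
    # Draw an index weighted by the probabilities.
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

# Example: 3 discrete actions; the logits heavily favor action 2.
logits = [0.1, 0.2, 5.0]
action = sample_action(logits)
```

For continuous action spaces the same pattern holds, but the logits are instead interpreted as distribution parameters (e.g. mean and log-std of a diagonal Gaussian) and the sample is a real-valued vector.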