My environment has a spaces.Dict action space. I wrote a custom DistributionWrapper that returns a dict, and sampling looks fine:
```python
class CustomDistribution(TorchDistributionWrapper):
    def sample(self):
        idx_dist = self._idx_distribution()
        idx = idx_dist.sample()
        type_dists = self._type_distribution(idx)
        actions = [d.sample() for d in type_dists]
        # Total logp = logp of the index + logp of each sampled parameter
        self._action_logp = idx_dist.logp(idx) + sum(
            d.logp(a) for d, a in zip(type_dists, actions)
        )
        return {
            "idx": idx,
            "p1": actions[0],
            "p2": actions[1],
            "p3": actions[2],
        }
```
But I'm having a problem implementing the logp() method. It takes actions as input in tensor format, not as a dictionary, and the tensor has shape [batch, 4] (which makes sense, since I have 4 action components).
Does anyone know where I can find the code that does the dict → tensor conversion?
Because I'm getting the following error:

```
ValueError: The value argument must be within the support
```

which suggests the order of my action components is getting mixed up!
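One assumption worth checking: gym's spaces.Dict keeps its subspaces in sorted key order, so the flattened action tensor most likely lays out its columns in that same order ("idx", "p1", "p2", "p3"). Under that assumption, a minimal sketch of recovering the dict inside logp() would just slice the columns back out. This uses NumPy for brevity (the same slicing works on torch tensors); `unflatten_actions` and the `KEYS` list are hypothetical names, not RLlib API:

```python
import numpy as np

# Columns assumed to follow sorted() order of the Dict space's keys.
KEYS = ["idx", "p1", "p2", "p3"]

def unflatten_actions(flat):
    """Split a flat [batch, 4] action array back into per-key columns."""
    return {key: flat[:, i] for i, key in enumerate(KEYS)}

# Example batch of 2 flattened actions:
batch = np.array([[2, 0, 1, 3],
                  [0, 1, 1, 0]])
parts = unflatten_actions(batch)
# parts["idx"] is the first column, parts["p3"] the last.
```

If the error persists with this ordering, it's worth printing a sampled action dict next to the tensor logp() receives to confirm which column maps to which key.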