How to choose the action dist for a custom model with a Tuple action space?

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

I am writing a custom model for a custom environment that has a Tuple action space. If I understood correctly, the trainer needs to use the appropriate action distribution to treat the model output, which would be a tuple of tensors.

I’d expect the trainer to create the action distribution automatically, since it already knows the action space from the env, but that’s not the case (it uses the default one). And I don’t see where in the trainer config I could set up the action dist class.

Hi @erickrf,

I’ve never used a Tuple action space in RLlib, so I’m not sure how it works there.
However, I’ve previously worked with a custom environment that had tuple actions. In that case, I simply enumerated the tuples in a dictionary (e.g. {0: (1,1), 1: (2,1), 2: (1,2), and so on}, with keys starting at 0 so they match the actions of the Discrete space) and then used gym.spaces.Discrete(len(dictionary)) as the action space.
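A minimal sketch of that enumeration trick, assuming two discrete sub-actions with n and m choices. `build_action_table` and `decode_action` are hypothetical helper names, not part of gym or RLlib; `itertools.product` generates every (a1, a2) pair so the table always matches the sub-space sizes.

```python
import itertools


def build_action_table(n, m):
    """Enumerate every (a1, a2) pair of Tuple(Discrete(n), Discrete(m)).

    Keys start at 0 so they line up with the actions of Discrete(n * m).
    """
    return dict(enumerate(itertools.product(range(n), range(m))))


def decode_action(table, action):
    """Translate the flat integer chosen by the agent back into a tuple."""
    return table[int(action)]


table = build_action_table(3, 2)
# In the env (assumption): self.action_space = gym.spaces.Discrete(len(table))
print(len(table))               # 6
print(decode_action(table, 4))  # (2, 0)
```

The trade-off is that the flat Discrete space has n × m actions, so this stays practical only while the sub-spaces are small.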

It will use default settings if you don’t customize it. For Discrete actions it uses a Categorical distribution, and for Box actions a (diagonal) Normal distribution. For Tuple action spaces it uses the TorchMultiActionDistribution class, which in turn picks the default distribution for each sub-space. If you want to change the defaults, you can subclass TorchMultiActionDistribution and write your own.
For example, if your Tuple contains 2 Box actions and 1 Discrete action, and you want to assign 2 Beta distributions and 1 Categorical distribution to them, you can write a class like this:

```python
import torch
import tree  # dm-tree
from ray.rllib.models.torch.torch_action_dist import (
    TorchBeta, TorchCategorical, TorchDistributionWrapper,
    TorchMultiActionDistribution)
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space


class My_betadist(TorchMultiActionDistribution):
    def __init__(self, inputs, model, *, child_distributions, input_lens,
                 action_space):
        # Ignore the default children RLlib passes in: Beta for the two Box
        # sub-actions, Categorical for the Discrete one. The input_lens RLlib
        # computed still fit (2 * dim for each Box, n for the Discrete).
        child_distributions = [TorchBeta, TorchBeta, TorchCategorical]
        # Make sure inputs are a tensor on the model's device.
        if not isinstance(inputs, torch.Tensor):
            inputs = torch.from_numpy(inputs)
            if isinstance(model, TorchModelV2):
                inputs = inputs.to(next(model.parameters()).device)
        # Skip TorchMultiActionDistribution.__init__; run only the base setup.
        TorchDistributionWrapper.__init__(self, inputs, model)
        self.action_space_struct = get_base_struct_from_space(action_space)
        # Split the flat model output into one chunk per sub-action and
        # build each child distribution from its chunk.
        self.input_lens = tree.flatten(input_lens)
        flat_child_distributions = tree.flatten(child_distributions)
        split_inputs = torch.split(inputs, self.input_lens, dim=1)
        self.flat_child_distributions = tree.map_structure(
            lambda dist, input_: dist(input_, model),
            flat_child_distributions,
            list(split_inputs),
        )
```

Remember to register it after `ray.init()`, like this (`ModelCatalog` comes from `ray.rllib.models`):
`ModelCatalog.register_custom_action_dist("My_betadist", My_betadist)`
You should also set the action distribution in the `model` section of your trainer config: `"custom_action_dist": "My_betadist"`.

Just to be clear, I meant that I expected the trainer to use some action distribution capable of understanding my model’s output, which is a tuple of tensors.

Anyway, looking at your example, I see that the model is not supposed to return a tuple of logit tensors but a single one. So my question now is, what should be the model logits in case of an action space like Tuple(Discrete(n), Discrete(m))? Or for that matter, what about a dictionary action space?

The model always returns a single flat tensor, and that tensor is then split up to parameterize the distributions. So it doesn’t matter whether you use a Dict or a Tuple: everything gets flattened, and the only thing that matters is which distribution you use for each sub-action. The number of logits depends on that distribution: for Discrete(m) you need m logits, while for a Normal distribution you need 2 values per action dimension (in RLlib’s diagonal Gaussian, the mean and the log standard deviation). So for your Tuple action space, the `forward()` method of your model should return n + m logits (besides the state outputs for RNNs etc.).
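To make the counting concrete, here is a small pure-Python sketch (no RLlib needed) of the bookkeeping described above: work out how many values each sub-action consumes, then cut one flat output row into per-action chunks, much like the `torch.split(inputs, self.input_lens, dim=1)` call in the distribution class earlier in the thread. The `("discrete", n)` / `("box", dim)` tuples and both helper functions are hypothetical stand-ins for a flattened gym space, not RLlib API.

```python
def logits_per_subaction(sub_spaces):
    """Return how many model outputs each sub-distribution needs.

    sub_spaces is a hypothetical stand-in for a flattened Tuple/Dict space:
    ("discrete", n) -> Categorical, needs n logits
    ("box", dim)    -> Normal (mean + log-std) or Beta, needs 2 * dim values
    """
    lens = []
    for kind, size in sub_spaces:
        if kind == "discrete":
            lens.append(size)
        elif kind == "box":
            lens.append(2 * size)
        else:
            raise ValueError(f"unsupported sub-space kind: {kind!r}")
    return lens


def split_flat_output(flat, lens):
    """Cut one flat output row into consecutive chunks of the given lengths."""
    if len(flat) != sum(lens):
        raise ValueError(f"expected {sum(lens)} values, got {len(flat)}")
    chunks, start = [], 0
    for n in lens:
        chunks.append(flat[start:start + n])
        start += n
    return chunks


# Tuple(Discrete(n=3), Discrete(m=4)) -> n + m = 7 logits in total.
lens = logits_per_subaction([("discrete", 3), ("discrete", 4)])
print(lens, sum(lens))  # [3, 4] 7
print(split_flat_output(list(range(7)), lens))  # [[0, 1, 2], [3, 4, 5, 6]]
```

A Dict action space works the same way; the only requirement is that the model concatenates its outputs in the same order in which the sub-spaces are flattened.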

Oh thanks, that’s pretty simple after all, just a matter of `torch.cat([logits1, logits2], dim=-1)`.
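A minimal PyTorch sketch of that pattern, assuming a plain `torch.nn.Module` with one linear head per sub-action; `TwoHeadPolicyNet` and its layer sizes are hypothetical. Inside an RLlib `TorchModelV2`, the same concatenated tensor would be what `forward()` returns as the logits (alongside the state list).

```python
import torch
import torch.nn as nn


class TwoHeadPolicyNet(nn.Module):
    """Hypothetical policy body for Tuple(Discrete(n), Discrete(m))."""

    def __init__(self, obs_dim, n, m, hidden=64):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
        self.head1 = nn.Linear(hidden, n)  # logits for Discrete(n)
        self.head2 = nn.Linear(hidden, m)  # logits for Discrete(m)

    def forward(self, obs):
        features = self.body(obs)
        logits1 = self.head1(features)
        logits2 = self.head2(features)
        # One flat tensor of shape [batch, n + m]; the action distribution
        # splits it back into [batch, n] and [batch, m].
        return torch.cat([logits1, logits2], dim=-1)


net = TwoHeadPolicyNet(obs_dim=8, n=3, m=4)
out = net(torch.zeros(5, 8))
print(out.shape)  # torch.Size([5, 7])
```

Keeping separate heads is a design choice for readability; a single `nn.Linear(hidden, n + m)` would produce a tensor of the same shape.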