Right way to use tuple action space

Hi there, I’m using a custom environment with a tuple (gym space) action space.
TL;DR - I’m having trouble figuring out how to construct the output of the model from the forward function.

my action space is defined as:

Tuple((DiscreteWithDType(9, dtype=np.uint8), DiscreteWithDType(9, dtype=np.uint8)))
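For reference, a hedged sketch (not RLlib or gym code): since DiscreteWithDType is described later in the thread as a thin compatibility wrapper around gym's Discrete, the space above behaves like a standard Tuple of two Discrete(9) sub-spaces, so a sampled action is just a pair of integers in [0, 8]. The stand-in classes below are illustrative only:

```python
# Minimal stand-ins for gym.spaces.Discrete and gym.spaces.Tuple,
# for illustration only -- not the real gym classes.
import random

class Discrete:
    """A space of n integer actions: 0, 1, ..., n-1."""
    def __init__(self, n):
        self.n = n
    def sample(self):
        return random.randrange(self.n)
    def contains(self, x):
        return isinstance(x, int) and 0 <= x < self.n

class Tuple:
    """A product of sub-spaces; a sample is one sample per sub-space."""
    def __init__(self, spaces):
        self.spaces = spaces
    def sample(self):
        return tuple(s.sample() for s in self.spaces)
    def contains(self, x):
        return len(x) == len(self.spaces) and all(
            s.contains(v) for s, v in zip(self.spaces, x))

# Roughly equivalent to the space in the post above:
action_space = Tuple((Discrete(9), Discrete(9)))
action = action_space.sample()  # a pair such as (3, 7)
```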

And I don’t know how to output the value in the forward pass. Is there some example to look at?


I know it hasn’t been long since the last update, but time is of the essence :frowning:

I just need an example of the shape and type of the model’s forward output for this kind of action space (any tupled action space will do).

Hi @Ofir_Abu,

Perhaps looking through this example will help you figure out what you need.


This function here should also help:


Thank you! The first example helps, but I specifically have trouble with a custom TF model; I don’t know how to define the type and shape of the forward pass output.

I guess I will debug a simple case with a non-custom model to understand it, but if someone has a reference, that would be a great help :slight_smile:

@Ofir_Abu what is DiscreteWithDType? I’m not familiar with this. I don’t believe it’s a gym space…

Correct, but it’s a compatible wrapper around the official Discrete space of gym; it basically adds some extra dtype-casting functions.

Does someone have an example of the forward function’s output in a similar case?

thanks again @mannyv !
Is there an easy way to see how the output of the forward pass is constructed? Specifically, how do I debug it?

Hey @Ofir_Abu, your Tuple space causes RLlib to choose a MultiActionDistribution as the model’s output distribution (the model parameterizes this distribution type and outputs a matching number of nodes). The model’s output values are then split inside this distribution according to the individual sub-spaces (2x DiscreteWithDType), and actions are sampled from these two spaces individually, using the logits produced by your model.
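To make the shape concrete, here is a hedged, stdlib-only sketch of the mechanism described above: for Tuple(Discrete(9), Discrete(9)) the model’s forward pass emits one flat vector of 9 + 9 = 18 logits, which the distribution splits per sub-space and samples independently. The function names below are illustrative, not RLlib’s actual API:

```python
# Sketch of how a flat logit vector is split and sampled per sub-space.
# Illustrative only -- RLlib's MultiActionDistribution does this internally.
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sample_tuple_action(flat_logits, sub_sizes=(9, 9)):
    """Split flat logits by sub-space size and sample each Discrete action."""
    assert len(flat_logits) == sum(sub_sizes)
    actions, start = [], 0
    for n in sub_sizes:
        probs = softmax(flat_logits[start:start + n])
        actions.append(random.choices(range(n), weights=probs)[0])
        start += n
    return tuple(actions)

# So the custom model's forward() should return shape [batch_size, 18]:
flat_logits = [0.0] * 18           # dummy output for a single sample
action = sample_tuple_action(flat_logits)  # e.g. (4, 1)
```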

You can debug into your forward pass by setting a breakpoint in e.g. rllib/models/torch/torch_action_dist.py::TorchMultiActionDistribution::sample() (or the respective TF version) AND setting local_mode=True in your call to ray.init().
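A minimal sketch of that setup (the training code itself is elided; the exact trainer/config API may differ by Ray version):

```python
import ray

# local_mode=True runs everything in a single process, so breakpoints set in
# e.g. TorchMultiActionDistribution.sample() (or the TF equivalent) are hit
# directly in your debugger instead of inside a remote worker.
ray.init(local_mode=True)

# ... build your trainer and run training as usual ...
```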

Sounds great, thank you for the explanation about the two distributions!
I will try to debug it now and edit this message with the results :slight_smile:
