RLModule with autoregressive actions

How severe does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.

Hello!
I want to implement an autoregressive action space with the new RLModule stack.
I have implemented the following class to solve my task with a MultiDiscrete action space, but I think autoregressive actions would perform better.

class PPOModule(PPOTorchRLModule):
    def __init__(self, config: RLModuleConfig):
        super().__init__(config)
        self.config = config
    
    def setup(self):
        self.encoder = PPOEncoder(...)
        self.vf = PPOVf(...)
        self.pi = PPOPi(...)
        # Disable argument validation to work around the simplex-constraint error described at
        # https://discuss.pytorch.org/t/distributions-categorical-fails-with-constraint-simplex-but-manual-check-passes/163209
        Distribution.set_default_validate_args(False)
        self.action_dist_cls = TorchMultiCategorical.get_partial_dist_cls(...)

How should I implement a new TorchCategorical subclass based on this example?

How should I access the model from inside this Distribution, and which methods do I need to implement?
Thank you!

I don’t know whether this is the correct solution, but it seems to work for me:

class TorchAutoregressive(Distribution):
    _model: nn.Module

    @staticmethod
    def set_model(model):
        # The model is stored on the class itself, so all instances share it.
        TorchAutoregressive._model = model
    
    ...



class PPOModule(PPOTorchRLModule):
    
    ...

    def get_train_action_dist_cls(self) -> Type[Distribution]:
        self.action_dist_cls.set_model(self.pi_a2)
        return self.action_dist_cls

    def get_exploration_action_dist_cls(self) -> Type[Distribution]:
        self.action_dist_cls.set_model(self.pi_a2)
        return self.action_dist_cls

    def get_inference_action_dist_cls(self) -> Type[Distribution]:
        self.action_dist_cls.set_model(self.pi_a2)
        return self.action_dist_cls
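
To make the question about which functions the distribution needs more concrete, here is a minimal, framework-free sketch of what an autoregressive distribution over two discrete actions has to compute. It is not RLlib's `Distribution` API and not the code from the posts above: the class `AutoregressiveCategorical`, the function `toy_a2_head` (a hypothetical stand-in for a learned head such as `pi_a2`), and the method names are assumptions chosen for illustration. The idea is that the second action is sampled from logits conditioned on the first one, and the joint log-probability is the sum of both terms.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_entropy(probs: Sequence[float]) -> float:
    """Entropy of a categorical distribution, skipping zero-probability entries."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


class AutoregressiveCategorical:
    """Joint distribution over (a1, a2) with p(a1, a2) = p(a1) * p(a2 | a1)."""

    def __init__(
        self,
        a1_logits: Sequence[float],
        a2_head: Callable[[int], Sequence[float]],
    ) -> None:
        self.a1_probs = softmax(a1_logits)
        # Hypothetical conditional head: maps a sampled a1 to logits for a2.
        self.a2_head = a2_head

    def _a2_probs(self, a1: int) -> List[float]:
        return softmax(self.a2_head(a1))

    def sample(self, rng: random.Random) -> Tuple[int, int]:
        # Sample a1 first, then a2 from the distribution conditioned on a1.
        a1 = rng.choices(range(len(self.a1_probs)), weights=self.a1_probs)[0]
        a2_probs = self._a2_probs(a1)
        a2 = rng.choices(range(len(a2_probs)), weights=a2_probs)[0]
        return a1, a2

    def logp(self, action: Tuple[int, int]) -> float:
        # log p(a1, a2) = log p(a1) + log p(a2 | a1)
        a1, a2 = action
        return math.log(self.a1_probs[a1]) + math.log(self._a2_probs(a1)[a2])

    def entropy(self) -> float:
        # Chain rule: H(a1, a2) = H(a1) + sum_a1 p(a1) * H(a2 | a1)
        conditional = sum(
            p * categorical_entropy(self._a2_probs(a1))
            for a1, p in enumerate(self.a1_probs)
        )
        return categorical_entropy(self.a1_probs) + conditional


def toy_a2_head(a1: int) -> List[float]:
    """Toy conditional head: a2 tends to copy a1."""
    return [2.0, 0.0] if a1 == 0 else [0.0, 2.0]


dist = AutoregressiveCategorical([0.0, 1.0], toy_a2_head)
action = dist.sample(random.Random(0))
print(action, dist.logp(action), dist.entropy())
```

In a real RLModule the conditional head would be a network that also receives the encoder output, and the distribution would operate on batched tensors. The sketch only shows why the distribution needs access to the model: `logp` and `entropy` both have to re-evaluate the a2 head for a given a1.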