How severely does this issue affect your experience of using Ray?
Medium: It makes completing my task significantly more difficult, but I can work around it.
Hello!
I want to implement an autoregressive action space with the new RLModule stack.
I have implemented the following class, which solves my task with a MultiDiscrete action space, but I think autoregressive actions would perform better.
class PPOModule(PPOTorchRLModule):
    """PPO RLModule assembled from an encoder, a value head and a policy head.

    Actions use a partial ``TorchMultiCategorical`` distribution class, i.e.
    the MultiDiscrete formulation described in the post.
    """

    def __init__(self, config: RLModuleConfig):
        super().__init__(config)
        # Keep a direct handle on the module config.
        # NOTE(review): the RLModule base class may already store this;
        # confirm whether this assignment is redundant.
        self.config = config

    def setup(self):
        """Create the sub-networks and select the action distribution class."""
        # Constructor arguments were elided ("...") in the original post.
        self.encoder = PPOEncoder(...)  # presumably the shared feature encoder
        self.vf = PPOVf(...)  # presumably the value-function head
        self.pi = PPOPi(...)  # presumably the policy head

        # Work around the simplex-constraint validation failure reported at
        # https://discuss.pytorch.org/t/distributions-categorical-fails-with-constraint-simplex-but-manual-check-passes/163209 error
        # NOTE(review): assuming `Distribution` is torch.distributions.Distribution,
        # this switches argument validation off for the whole process, not
        # just for this module.
        Distribution.set_default_validate_args(False)

        self.action_dist_cls = TorchMultiCategorical.get_partial_dist_cls(...)
How should I implement a new TorchCategorical subclass based on the following example?
# NOTE(review): method-body fragment -- the enclosing `def` line was not part
# of the paste (it uses `self`, an outer `a1`, and a bare `return`). It is
# presumably the body of `_a2_distribution(self, a1)`, which is called from
# `deterministic_sample` further down; confirm against the upstream example.
# It builds the distribution of a2 given a1.
# Cast a1 to float32 and insert a new axis at position 1 so it can be passed
# to the model alongside `self.inputs`.
a1_vec = tf.expand_dims(tf.cast(a1, tf.float32), 1)
# The first output of `action_model` is discarded; only the a2 logits are used.
_, a2_logits = self.model.action_model([self.inputs, a1_vec])
# Wrap the a2 logits in a Categorical distribution and hand it back.
a2_dist = Categorical(a2_logits)
return a2_dist
@staticmethod
def required_model_output_shape(action_space, model_config):
    """Return the width of the model's output feature vector.

    Neither ``action_space`` nor ``model_config`` is consulted; the size is
    always the constant below.
    """
    feature_size = 16  # controls model output feature vector size
    return feature_size
class TorchBinaryAutoregressiveDistribution(TorchDistributionWrapper):
    """Action distribution P(a1, a2) = P(a1) * P(a2 | a1)"""

    def deterministic_sample(self):
        """Deterministically sample a1, then sample a2 conditioned on that a1.

        Relies on ``self._a1_distribution()`` and
        ``self._a2_distribution(a1)``, which are not part of this paste.
        """
        # Step 1: first action component, unconditioned.
        dist_a1 = self._a1_distribution()
        sample_a1 = dist_a1.deterministic_sample()

        # Step 2: second action component, whose distribution depends on a1.
        dist_a2 = self._a2_distribution(sample_a1)
        sample_a2 = dist_a2.deterministic_sample()

        # NOTE(review): as pasted, the method ends here and implicitly returns
        # None, so `sample_a2` is unused. The post looks truncated; presumably
        # the full version returns the sampled (a1, a2) pair -- confirm upstream.
How should I access the model from inside this Distribution, and which methods do I need to implement?
Thank you!
I don’t know whether this is a correct solution, but it seems to work for me:
class TorchAutoregressive(Distribution):
    """Autoregressive action distribution that needs a torch sub-network.

    The sub-network is attached to the class itself through ``set_model``
    rather than to individual instances. Every user of this class within one
    Python process therefore sees the same model, and the last call to
    ``set_model`` wins.
    """

    # Class-level slot for the sub-network. It is only annotated here; no
    # value is assigned in the visible code, so reading it before
    # `set_model` runs would fail unless the elided part of the class sets it.
    _model: nn.Module

    @staticmethod
    def set_model(model):
        """Attach ``model`` to ``TorchAutoregressive`` (shared, class-level)."""
        # Assigned on the base class by name, so subclasses also find it
        # through ordinary attribute lookup.
        TorchAutoregressive._model = model

    ...
class PPOModule(PPOTorchRLModule):
    """PPO RLModule whose action distribution is given the ``pi_a2`` network."""

    ...

    def _action_dist_cls_with_model(self) -> Type[Distribution]:
        """Hand ``self.pi_a2`` to the action distribution class and return it.

        All three ``get_*_action_dist_cls`` getters delegate here, so the
        distribution class always receives the same sub-network regardless
        of which getter RLlib calls.
        """
        # Side effect: `set_model` runs on every getter call.
        # NOTE(review): if `set_model` stores the model at class level (as in
        # `TorchAutoregressive`), all modules in one process share a single
        # `pi_a2` and the most recent caller wins -- confirm this is intended.
        self.action_dist_cls.set_model(self.pi_a2)
        return self.action_dist_cls

    def get_train_action_dist_cls(self) -> Type[Distribution]:
        """Return the action distribution class used for training."""
        return self._action_dist_cls_with_model()

    def get_exploration_action_dist_cls(self) -> Type[Distribution]:
        """Return the action distribution class used for exploration."""
        return self._action_dist_cls_with_model()

    def get_inference_action_dist_cls(self) -> Type[Distribution]:
        """Return the action distribution class used for inference."""
        return self._action_dist_cls_with_model()