Hi,
I want to add masking to my custom autoregressive ActionDistribution, but I don’t see any way to get the mask. Normally I provide the mask from the env as an input to the custom model. But now that the sampling happens inside the ActionDistribution, I don’t know how to get the mask and apply it to the action_head. Is there any sample or idea of how to do it?
class CustomMultiCategorical(MultiCategorical):
    """MultiCategorical whose component Categoricals are built autoregressively.

    Instead of splitting the model output into independent logit chunks, each
    step's logits are produced by ``model.nlp_action_head``, which receives the
    full ``inputs``, the value sampled at the previous step, and the decoder
    state returned by the previous step.
    """

    def __init__(
        self,
        inputs: List[TensorType],
        model: ModelV2,
        input_lens: Union[List[int], np.ndarray, Tuple[int, ...], None] = None,
        action_space=None,
        *args,
        **kwargs
    ):
        """Build one Categorical per entry of ``input_lens``, step by step.

        Args:
            inputs: Model output; passed unchanged to the action head at
                every step.
            model: Model exposing ``nlp_action_head``.
            input_lens: Per-step split sizes. Defaults to
                ``[VOCAB_NUM_CLASSES] * MAX_SENTENCE_LEN`` when omitted.
            action_space: Optional action space; when ``None`` a
                ``MultiDiscrete`` space is derived from the built Categoricals.
            *args: Accepted for call compatibility and ignored.
            **kwargs: Accepted for call compatibility and ignored.
        """
        # Resolve the default at call time so that no list object is shared
        # between calls through the function signature.
        if input_lens is None:
            input_lens = [VOCAB_NUM_CLASSES] * MAX_SENTENCE_LEN

        # skip TFActionDistribution init
        ActionDistribution.__init__(self, inputs, model)
        self.cats = []
        self.sampled_cats = []

        batch_size = tf.shape(inputs)[0]
        # NOTE(review): VOCAB_NUM_CLASSES presumably serves as a start-of-sequence
        # token id here - confirm against the model's embedding size.
        last_cat_res = tf.fill((batch_size, 1), VOCAB_NUM_CLASSES)
        decoder_states = (None, None)

        # The split chunks themselves are discarded; only their count drives
        # the number of autoregressive steps.
        for _ in tf.split(inputs, input_lens, axis=1):
            last_cat, decoder_states = self._get_autoregressive_categorical(
                inputs, model, last_cat_res, decoder_states
            )
            self.cats.append(last_cat)
            # NOTE(review): the initial value has shape (batch, 1); verify that
            # last_cat.sample() yields a shape the action head also accepts.
            last_cat_res = last_cat.sample()
            self.sampled_cats.append(last_cat_res)

        self.action_space = action_space
        if self.action_space is None:
            self.action_space = gym.spaces.MultiDiscrete(
                [c.inputs.shape[1] for c in self.cats]
            )
        self.sample_op = self._build_sample_op()
        self.sampled_action_logp_op = self.logp(self.sample_op)

    def _get_autoregressive_categorical(
        self, input_, model, last_cat_res, decoder_states
    ):
        """Run one action-head step and wrap its logits in a Categorical.

        Args:
            input_: Tensor forwarded to ``model.nlp_action_head``.
            model: Model exposing ``nlp_action_head``.
            last_cat_res: Value sampled at the previous step.
            decoder_states: Pair of states from the previous step
                (``(None, None)`` on the first step).

        Returns:
            A tuple ``(Categorical, (state_h, state_c))`` holding the new
            distribution and the states returned by the action head.
        """
        logits, state_h, state_c = model.nlp_action_head(
            input_, last_cat_res, decoder_states[0], decoder_states[1]
        )
        return Categorical(logits, model), (state_h, state_c)

    @staticmethod
    @override(ActionDistribution)
    def required_model_output_shape(
        action_space: gym.Space, model_config: ModelConfigDict
    ) -> Union[int, np.ndarray]:
        """Return the model output size required for ``action_space``.

        For an integer ``Box`` whose bounds are uniform this is the number of
        elements times the number of values per element; otherwise it is the
        sum of ``action_space.nvec``.
        """
        # Int Box.
        if isinstance(action_space, gym.spaces.Box):
            assert action_space.dtype.name.startswith("int")
            low_ = np.min(action_space.low)
            high_ = np.max(action_space.high)
            # Every element must share the same bounds.
            assert np.all(action_space.low == low_)
            assert np.all(action_space.high == high_)
            return np.prod(action_space.shape, dtype=np.int32) * (high_ - low_ + 1)
        # MultiDiscrete space.
        else:
            # nvec is already integer, so no casting needed.
            return np.sum(action_space.nvec)