PPO centralized critic model with action masking

I have a system heavily inspired by the centralized_critic_model example.
I have been reading some action masking examples, and I want to implement action masking in my system. I think I understood which part of the env I should modify; what I am unsure about is what to change in the model. In the examples I see that I should access input_dict["obs"]["action_mask"] and modify the logits according to what I want to do with action masking.
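The logit-masking trick those examples use can be sketched framework-free in NumPy. This is a minimal sketch, assuming a 0/1 mask vector with the same length as the logits; `mask_logits` and `softmax` are hypothetical helper names, not RLlib APIs. Taking the log maps allowed actions (1) to 0 and forbidden ones (0) to -inf, and clamping to the float32 minimum keeps the sum finite:

```python
import numpy as np

FLOAT_MIN = np.finfo(np.float32).min  # plays the role of tf.float32.min


def mask_logits(logits, action_mask):
    """Push logits of forbidden actions down to ~float32 min (hypothetical helper)."""
    with np.errstate(divide="ignore"):  # log(0) = -inf is intended here
        inf_mask = np.maximum(np.log(action_mask), FLOAT_MIN)
    return logits + inf_mask


def softmax(x):
    """Numerically stable softmax over a 1-D array."""
    e = np.exp(x - x.max())
    return e / e.sum()


logits = np.array([2.0, 1.0, 0.5, 3.0], dtype=np.float32)
action_mask = np.array([1.0, 1.0, 1.0, 0.0], dtype=np.float32)  # action 3 is forbidden
probs = softmax(mask_logits(logits, action_mask))
print(probs.round(3))
```

Even though action 3 has the largest raw logit, its probability ends up at zero. In TF the same line is tf.maximum(tf.math.log(mask), tf.float32.min), which is the pattern the action masking example follows.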

My main doubts are:

```python
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
    return self.model.forward(input_dict, state, seq_lens)
```

This is the forward method in the centralized critic model (line 43). How could I modify it? I am not very familiar with the PublicAPI annotations, and I don't know where I should start.

Also, I see that most action masking examples rely on modifying the logits and the state, whereas my implementation follows the TwoStepGame example env, which by default has state and logits set to False. Would that be a problem?

Thanks in advance!

An update on my issue:
I think I now understand what is happening in forward: the lines I quoted delegate to FullyConnectedNetwork's forward, so I guess I should do my customizations there.
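Structurally, the customization amounts to wrapping the inner model's forward: run it on the real observation, then add the log-mask to its output. The sketch below is plain Python/NumPy so it runs without Ray; `MaskedModelSketch`, `toy_base_forward`, and the "obs"/"action_mask" keys of the Dict observation are hypothetical stand-ins for your RLlib model, the wrapped FullyConnectedNetwork, and your own observation space.

```python
import numpy as np

FLOAT_MIN = np.finfo(np.float32).min


def toy_base_forward(input_dict, state, seq_lens):
    """Hypothetical stand-in for the wrapped FullyConnectedNetwork: all-zero logits."""
    batch = input_dict["obs"].shape[0]
    return np.zeros((batch, 4), dtype=np.float32), state


class MaskedModelSketch:
    """Mimics the shape of a custom forward that masks the inner model's logits."""

    def __init__(self, base_forward):
        self.base_forward = base_forward

    def forward(self, input_dict, state, seq_lens):
        obs = input_dict["obs"]
        # 1) The inner model sees only the "real" observation, not the mask.
        logits, state = self.base_forward({"obs": obs["obs"]}, state, seq_lens)
        # 2) Turn the 0/1 mask into 0 / ~-inf and add it to the logits.
        with np.errstate(divide="ignore"):
            inf_mask = np.maximum(np.log(obs["action_mask"]), FLOAT_MIN)
        return logits + inf_mask, state


model = MaskedModelSketch(toy_base_forward)
batch_obs = {
    "obs": np.zeros((2, 3), dtype=np.float32),
    "action_mask": np.array([[1, 1, 0, 0], [0, 1, 1, 1]], dtype=np.float32),
}
masked, _ = model.forward({"obs": batch_obs}, [], None)
print(masked.argmax(axis=1))
```

The sketch only touches the logits, so a separate value branch such as the centralized critic would be left as-is.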

Also, thinking about my implementation, I realized that my mask is the same as the observation!
I have Discrete(100) as the action space and Discrete(100) as the observation space, and I want my action to be smaller than the observation
(e.g. if my obs is 70, my action can be any number X with 0 < X < 70).
How can I achieve that?
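One way to express 0 < X < obs is to build the mask from the observation itself: a 0/1 vector over all 100 actions with ones only at the allowed indices. A minimal NumPy sketch, assuming the strict bounds from the example above; `mask_from_obs` is a hypothetical helper that would run on the env side, so the mask reaches the model as its own observation entry rather than being derived from the Discrete observation inside the model:

```python
import numpy as np

N_ACTIONS = 100  # matches Discrete(100)


def mask_from_obs(obs, n_actions=N_ACTIONS):
    """Return a float32 0/1 vector allowing only actions X with 0 < X < obs."""
    actions = np.arange(n_actions)
    return ((actions > 0) & (actions < obs)).astype(np.float32)


mask = mask_from_obs(70)
allowed = np.flatnonzero(mask)
print(allowed.min(), allowed.max(), int(mask.sum()))
```

Note that for obs <= 1 this mask is all zeros, which leaves no legal action at all, so that edge case would need its own rule (for example, also allowing action 0).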

I am trying it this way:

```python
@override(ModelV2)
def forward(self, input_dict, state, seq_lens):
    model_out, self._value_out = self.base_model(input_dict["obs_flat"])
    inf_mask = tf.maximum(tf.math.log(input_dict["obs_flat"]), tf.float32.min)
    masked_logits = model_out + inf_mask
    return masked_logits, state
```

Basically, I am copying what happens in the action masking example, but putting my observation into inf_mask instead of the action mask.
What I get is that every action equals the observation (if I give obs = 10, I get action = 10).
What am I doing wrong here? (I think using a Discrete observation in the inf_mask = … line might be the issue, but I am not sure.)
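A likely explanation, assuming RLlib's default preprocessing one-hot encodes a Discrete observation into obs_flat: for obs = 10, obs_flat is a length-100 vector with a single 1 at index 10, so log(obs_flat) is 0 there and -inf everywhere else. The "mask" then allows exactly one action, the observation itself. The NumPy sketch below reproduces that effect, with random logits standing in for model_out:

```python
import numpy as np

FLOAT_MIN = np.finfo(np.float32).min
N = 100
rng = np.random.default_rng(0)

obs = 10
obs_flat = np.zeros(N, dtype=np.float32)
obs_flat[obs] = 1.0  # assumed one-hot encoding of Discrete obs=10

model_out = rng.normal(size=N).astype(np.float32)  # stand-in for the policy logits
with np.errstate(divide="ignore"):
    inf_mask = np.maximum(np.log(obs_flat), FLOAT_MIN)  # same as the TF line, in NumPy
masked_logits = model_out + inf_mask

# Only index `obs` keeps a finite logit, so it is always the chosen action.
print(int(np.argmax(masked_logits)))
```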

I think I got this. In case anyone is interested, here is my solution:
the action mask had to be a Box() of 1/0 flags marking which actions are available.
What I did was pass a Box with as many ones as the maximum value I wanted for the action.
That seems to work!
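The fix can be sanity-checked outside RLlib: with a multi-hot mask (a 1 for every allowed action, as in a Dict observation whose "action_mask" entry is a Box(0, 1, shape=(100,))) instead of the one-hot observation, sampling from the masked softmax should never produce an action at or above the limit. A NumPy sketch, assuming an upper limit of 70 and ones at indices 0..69 ("as many ones as the max value"); the random logits stand in for the model output:

```python
import numpy as np

FLOAT_MIN = np.finfo(np.float32).min
N = 100
limit = 70
rng = np.random.default_rng(42)

# Multi-hot 0/1 mask: ones for actions 0..limit-1, zeros for the rest.
action_mask = (np.arange(N) < limit).astype(np.float32)

model_out = rng.normal(size=N)  # stand-in for the policy logits
with np.errstate(divide="ignore"):
    masked_logits = model_out + np.maximum(np.log(action_mask), FLOAT_MIN)

# Softmax, then sample many actions from the masked distribution.
probs = np.exp(masked_logits - masked_logits.max())
probs /= probs.sum()
actions = rng.choice(N, size=10_000, p=probs)
print(actions.max() < limit)
```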