My observation space is a `Dict`:
{
    'observation': Dict({ .. dictionary with observations }),
    'mask': Box()  # Mask for action masking
}
and I have a custom model:
class DQNModel(TFModelV2):
    """Custom model that masks the logits produced by a wrapped network.

    The incoming observation is read as a mapping with two entries:
    ``'observation'`` (fed to an internal ``FullyConnectedNetwork``) and
    ``'mask'`` (used to push the logits of masked-out actions down).
    """

    def __init__(
        self,
        obs_space: Space,
        act_space: Space,
        num_outputs: int,
        model_config: Dict,
        name: str,
    ):
        """Initialise the base model and build the wrapped network.

        All five arguments are forwarded unchanged to the parent constructor.
        The wrapped network receives the same action space, output count and
        config, but only the ``'observation'`` sub-space and a name suffixed
        with ``'_internal'``.
        """
        super().__init__(obs_space, act_space, num_outputs, model_config, name)

        # Referred to as "line 10" in the surrounding discussion: use the
        # space's ``original_space`` attribute when it has one, otherwise the
        # space itself.
        orig_space = getattr(obs_space, "original_space", obs_space)

        # Referred to as "line 12" in the surrounding discussion.
        # NOTE(review): the discussion reports that
        # ``orig_space['observation']`` is a ``gym.spaces.Dict`` with no
        # ``shape`` attribute, so this constructor call raises as written.
        self.internal_model = FullyConnectedNetwork(
            orig_space['observation'],
            act_space,
            num_outputs,
            model_config,
            name + '_internal',
        )

    def forward(self, input_dict, state, seq_lens):
        """Compute masked logits for a batch.

        Returns a ``(masked_logits, state)`` pair; ``state`` is passed through
        untouched and ``seq_lens`` is not used.
        """
        obs = input_dict['obs']
        action_mask = obs['mask']
        logits, _ = self.internal_model({'obs': obs['observation']})

        # Transform 0s in mask into -inf (using tf.float32.min to avoid NaNs)
        # assumes mask entries are 0 or 1 -- TODO confirm against the env.
        log_mask = tf.math.log(action_mask)
        return logits + tf.maximum(log_mask, tf.float32.min), state

    def value_function(self):
        """Return the value-function output of the wrapped network."""
        wrapped = self.internal_model
        return wrapped.value_function()
The model is adapted from the example `action_mask_model.py`.
However, `orig_space['observation']` in line 10 is of type `gym.spaces.Dict`, which has no `shape` attribute, so the `FullyConnectedNetwork` `__init__` at line 12 will raise an error.
My understanding is that I need to take the space in `orig_space['observation']`, preprocess it so that it will be unflattened, and pass the result to the model in line 12. Is that correct? If so, how can I do this?