I've added the class definition below. However, my intuition is that the issue falls outside my class, since the obs_space I describe comes directly from what my class receives in its __init__.
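For context, here is a minimal sketch of the kind of observation space I mean. The 'mask' key is what my forward() looks for; 'actual_obs', the shapes, and the number of actions are just illustrative:

from gym.spaces import Box, Dict as GymDict, Discrete

num_actions = 5
obs_space = GymDict({
    # The observation the agent actually learns from (shape is illustrative)
    'actual_obs': Box(low=-1.0, high=1.0, shape=(10,)),
    # One entry per action: 1.0 = available, 0.0 = unavailable
    'mask': Box(low=0.0, high=1.0, shape=(num_actions,)),
})
action_space = Discrete(num_actions)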
from typing import Dict, List, Tuple

from gym.spaces import Dict as GymDict, Discrete
from ray.rllib.agents.dqn.distributional_q_tf_model import DistributionalQTFModel
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.typing import ModelConfigDict, TensorType

tf1, tf, tfv = try_import_tf()


class DQNModel(DistributionalQTFModel):
    """
    Custom DQN model that implements action masking by setting the Q values of
    unavailable actions to tf.float32.min.
    """

    def __init__(
        self,
        obs_space: GymDict,
        action_space: Discrete,
        num_outputs: int,
        model_config: ModelConfigDict,
        name: str,
        q_hiddens: Tuple = (),  # must stay empty, see validate_config()
        dueling: bool = False,
        num_atoms: int = 1,
        use_noisy: bool = False,
        v_min: float = -10.0,
        v_max: float = 10.0,
        sigma0: float = 0.5,
        add_layer_norm: bool = False,
    ):
        super(DQNModel, self).__init__(obs_space, action_space, num_outputs, model_config, name,
                                       q_hiddens, dueling, num_atoms, use_noisy, v_min, v_max,
                                       sigma0, add_layer_norm)
        # Ensure that the configuration does not make RLlib add layers on top of our model
        self.validate_config(num_atoms, q_hiddens, dueling, use_noisy, model_config)
        # The internal model takes the observation as input and produces the Q values
        self.internal_model = FullyConnectedNetwork(
            obs_space, action_space, action_space.n, model_config, f'{name}_internal')

    def forward(self,
                input_dict: Dict[str, TensorType],
                state: List[TensorType],
                seq_lens: TensorType) -> Tuple[TensorType, List[TensorType]]:
        """
        Compute the (masked) Q values.

        :param input_dict: Dictionary with keys 'obs' (the original observation
            dictionary) and 'obs_flat' (the flattened observation)
        :param state: Not used in this fully connected model
        :param seq_lens: Not used in this fully connected model
        :return q_values: Tensor with (masked) Q values
        :return state: The state received as input (not used here)
        """
        observation = input_dict['obs_flat']
        q_values, _ = self.internal_model({'obs': observation})
        if 'mask' in input_dict['obs']:
            mask = input_dict['obs']['mask']
            # log(1) = 0 leaves available actions untouched; log(0) = -inf is
            # clamped to tf.float32.min so masked actions can never be the argmax
            inf_mask = tf.maximum(tf.math.log(mask), tf.float32.min)
            q_values = q_values + inf_mask
        return q_values, state

    def validate_config(self, num_atoms: int, q_hiddens: Tuple, dueling: bool, use_noisy: bool,
                        model_config: Dict):
        """
        Check that the received config does not make RLlib add layers on top of our model,
        since we want full control of it and to output q_values directly in forward().
        """
        assert num_atoms == 1, 'Distributional DQN not implemented yet'
        assert not q_hiddens, 'Additional Q layers are not supported with action masking'
        assert not dueling, 'Dueling architecture not supported with action masking'
        assert not use_noisy, 'Additional noisy layers not supported with action masking'
        assert not model_config['post_fcnet_hiddens'], \
            'All hidden layers must be specified in fcnet_hiddens'
        assert not model_config['no_final_linear'], 'Do not skip the final linear layer in DQN.'

    def value_function(self) -> TensorType:
        raise NotImplementedError('Not used in DQN')

    def import_from_h5(self, h5_file: str) -> None:
        raise NotImplementedError('Import from h5 not implemented.')
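For completeness, this is roughly how I register the model and configure the trainer so that validate_config() passes. The model name 'masked_dqn' and the env id are placeholders; the other keys are the standard DQN/model config keys:

from ray.rllib.models import ModelCatalog

ModelCatalog.register_custom_model('masked_dqn', DQNModel)

config = {
    'env': 'MyMaskedEnv-v0',  # placeholder: an env whose obs space includes 'mask'
    'hiddens': [],            # becomes q_hiddens; must be empty (see validate_config)
    'dueling': False,
    'num_atoms': 1,
    'noisy': False,
    'model': {
        'custom_model': 'masked_dqn',
        'fcnet_hiddens': [256, 256],  # all hidden layers go here
        'post_fcnet_hiddens': [],     # checked in validate_config
        'no_final_linear': False,     # checked in validate_config
    },
}

As far as I understand, with dueling disabled and 'hiddens' empty, RLlib does not wrap the model in extra Q/value heads, so forward() can output the Q values directly.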