Custom model for DQN

Writing a custom model to use with DQN is really not straightforward from the examples.


My custom model is a subclass of DQNTorchModel: I customized forward(), get_state_value(), and get_q_value_distributions(), but when I train, I get the following error:

ValueError: The parameter logits has invalid values


I looked at the example parametric_actions_cartpole.py to try to understand what’s going on, but I’m even more confused.

In this example, a custom model, TorchParametricActionsModel, is used.
This model is a subclass of DQNTorchModel, but only forward() is overridden (not get_q_value_distributions() or get_state_value()).


I thought that in DQN the action is selected by taking the one with the maximum Q-value?
So how can the action be selected if get_q_value_distributions() is not implemented?
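To be clear about what I mean, here is a toy illustration of the selection rule I have in mind (just my own sketch, not RLlib code):

import torch

q_values = torch.tensor([[0.1, 2.3, -0.5]])  # Q(s, a) for one observation and 3 actions
action = torch.argmax(q_values, dim=-1)  # greedy selection -> tensor([1])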

Hi @Astariul,

It works because of inheritance. TorchParametricActionsModel is a subclass of DQNTorchModel, which means TPAM inherits all of the methods implemented by DQNTorchModel. The default implementations were compatible with the changes made in TPAM, so they did not have to be overridden.
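For illustration, a custom model in the spirit of that example only needs to override forward(); the Q/value heads come from the parent class. Something roughly like this (a sketch, not the exact example code; the import paths and the observation keys real_obs / action_mask are assumptions that depend on your Ray version and your env):

import torch
from gym.spaces import Box

# Import paths below are for Ray 1.x and may differ in other versions (assumption).
from ray.rllib.agents.dqn.dqn_torch_model import DQNTorchModel
from ray.rllib.models.torch.fcnet import FullyConnectedNetwork as TorchFC

# Stand-ins for the FLOAT_MIN / FLOAT_MAX constants used in the RLlib example
# (their import path varies across Ray versions, so they are defined inline here).
FLOAT_MIN = float(torch.finfo(torch.float32).min)
FLOAT_MAX = float(torch.finfo(torch.float32).max)


class MaskedDQNModel(DQNTorchModel):
    """Sketch: only forward() is overridden; everything else is inherited."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name,
                 true_obs_shape=(4,), **kw):
        DQNTorchModel.__init__(self, obs_space, action_space, num_outputs,
                               model_config, name, **kw)
        # Inner network that embeds the "real" observation (shape is an assumption).
        self.embed_model = TorchFC(
            Box(-1.0, 1.0, shape=true_obs_shape), action_space, num_outputs,
            model_config, name + "_embed")

    def forward(self, input_dict, state, seq_lens):
        # Per-action scores computed from the observation.
        logits, _ = self.embed_model({"obs": input_dict["obs"]["real_obs"]})
        # Mask out unavailable actions so they can never be selected.
        action_mask = input_dict["obs"]["action_mask"]
        inf_mask = torch.clamp(torch.log(action_mask), FLOAT_MIN, FLOAT_MAX)
        return logits + inf_mask, state

    # get_q_value_distributions() and get_state_value() are NOT overridden here:
    # the defaults from DQNTorchModel are applied to whatever forward() returns.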

Hi @mannyv ,

Thanks for your fast answer!

I understand. But I’m still confused about the roles of the methods get_q_value_distributions(), get_state_value(), forward(), and value_function().

I thought that:

  • forward() should return a feature vector representing the observation.
  • get_q_value_distributions() should return Q(s, a) for all a, given the feature vector.
  • get_state_value() should return V(s), given the feature vector.
  • value_function() was not used for DQN.

But, when I look at the example, it seems that :

  • forward() returns Q(s, a) for all a.
  • value_function() returns V(s).

The class inherits get_q_value_distributions() and get_state_value(), but I don’t understand what their roles are here, since the model already has methods to compute Q(s, a) and V(s).
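To make my mental model concrete, this is roughly the combination I was picturing, a toy dueling-style sketch (my own illustration, not RLlib's actual dqn_torch_policy code):

import torch

def toy_q_values(model_out, advantage_head, value_head):
    # Dueling combination: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
    advantages = advantage_head(model_out)   # per-action scores from the feature vector
    state_value = value_head(model_out)      # V(s) from the same feature vector
    return state_value + advantages - advantages.mean(dim=1, keepdim=True)

model_out = torch.randn(2, 8)                # batch of 2 "feature vectors" from forward()
adv_head = torch.nn.Linear(8, 4)             # 4 actions
val_head = torch.nn.Linear(8, 1)
print(toy_q_values(model_out, adv_head, val_head).shape)  # torch.Size([2, 4])

If that picture is right, I would expect get_q_value_distributions() and get_state_value() to play the roles of the two heads, hence my confusion about where forward() and value_function() fit in the example.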


I couldn’t find an explanation of the role of each of these methods in the docs… So if I want to write a custom model for the DQN algorithm, which methods should I override, and what are their roles?

After further debugging, I found that the error came from the Q-values being inf.

It was because the function reduce_mean_ignore_inf() was not ignoring the -inf values, since I had followed the example and written:

inf_mask = torch.clamp(mask.log(), FLOAT_MIN, FLOAT_MAX)

Changing the code to:

inf_mask = mask.log()

solved it.
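For anyone hitting the same thing, the difference is easy to see in isolation (FLOAT_MIN / FLOAT_MAX are defined inline here as stand-ins for the RLlib constants):

import torch

FLOAT_MIN = float(torch.finfo(torch.float32).min)  # stand-in for RLlib's constant
FLOAT_MAX = float(torch.finfo(torch.float32).max)

mask = torch.tensor([1.0, 0.0, 1.0])  # action 1 is invalid

print(mask.log())                                     # tensor([0., -inf, 0.])        -> -inf can be detected and skipped
print(torch.clamp(mask.log(), FLOAT_MIN, FLOAT_MAX))  # tensor([0., -3.4028e+38, 0.]) -> finite, so it is not skipped

With the clamped version, the masked-out entries are huge finite negatives instead of -inf, so reduce_mean_ignore_inf() has nothing to skip and those entries pollute the mean, which is how I ended up with inf Q-values.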