Value of num_outputs of DQNTrainer

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

I have implemented my own gym environment and am now creating my custom NN. My action space and observation space are as follows:

import gym
import numpy as np

h, w, k = 4, 8, 3
action_space = gym.spaces.Discrete(h*w)
observation_space = gym.spaces.Dict({"array1": gym.spaces.Box(low=np.float32(0),
                                                              high=np.float32(1),
                                                              shape=[h,w],
                                                              dtype=np.float32),
                                     "array2": gym.spaces.Box(low=np.zeros([k], dtype=np.float32),
                                                              high=np.array([h, w, 1], dtype=np.float32),
                                                              dtype=np.float32)})

print(action_space, observation_space["array1"].shape, observation_space["array2"].shape)

Discrete(32) (4, 8) (3,)

Since the action space is Discrete(32), I expected num_outputs to equal 32. However, it equals 256. Below is a short script that reproduces this using the CartPole environment.

import torch as th
import torch.nn as nn

import ray
from ray.rllib.agents import dqn
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2


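class CustomTorchModel(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        # RLlib torch models inherit from both TorchModelV2 and nn.Module
        # and initialize both bases explicitly.
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs, model_config, name)
        nn.Module.__init__(self)

        print(obs_space.shape, action_space, num_outputs, name)
        print()

    def forward(self, input_dict, state, seq_lens):
        # Dummy output of shape [batch_size, num_outputs] (demonstration only;
        # the batch size is taken from the input instead of being hard-coded).
        batch_size = input_dict["obs"].shape[0]
        return th.zeros([batch_size, self.num_outputs]), state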


ModelCatalog.register_custom_model("my_torch_model", CustomTorchModel)

ray.init(ignore_reinit_error=True)
trainer = dqn.DQNTrainer(env="CartPole-v0",
                         config={"framework": "torch",
                                 "model": {"custom_model": "my_torch_model",
                                           "custom_model_config": {}}})

(4,) Discrete(2) 256 q_func
(4,) Discrete(2) 256 target_q_func

So the variable num_outputs is again equal to 256. I would like to hear why this is the case and where I can find documentation on this.


Hey @TedDeVriesLentsch, this is a very good question! Unfortunately, RLlib does not make it very clear what happens under the hood when it constructs the DQN model from a user-provided custom model.
For DQN, the standard procedure is as follows (you can check/debug this in rllib/agents/dqn/dqn_torch_policy.py::build_q_model_and_distribution):

  • Determine num_outputs: This is the number of nodes that your custom model should use as the size of its output layer (256 by default).
  • The reason num_outputs is not the size of the action space (2 in this case) is that, for DQN, the required output heads (advantage head AND value head) are created after(!) your custom architecture, on top of its output. Both heads must be present for the dueling DQN loss (dueling=True) to work. Their architecture is a simple MLP, defined by the hiddens parameter in your config ([256] by default).
  • So basically, the final architecture is as follows:
[your custom model (should output 256 nodes)] --> [ [256, 2] MLP Advantage head (2 outputs due to action space)] 
                                              \-> [ [256, 1] MLP Value head (1 output: the value)]
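As for where the 256 itself comes from: the sketch below mirrors, as an assumption, the logic in build_q_model_and_distribution in RLlib 1.x (please verify against the source of your installed version). If hiddens is non-empty, num_outputs appears to be the last entry of model.fcnet_hiddens (default [256, 256]), falling back to 256; if hiddens is empty, no heads are added and the custom model must output one Q-value per action. The helper name infer_dqn_num_outputs is hypothetical and not part of RLlib.

```python
def infer_dqn_num_outputs(config: dict, num_actions: int) -> int:
    """Hypothetical helper; an assumed mirror of RLlib 1.x DQN logic."""
    if config.get("hiddens"):
        # Heads are stacked on top of the custom model, so the model only
        # produces a feature vector. Its size is assumed to be the last entry
        # of model.fcnet_hiddens (RLlib default [256, 256]), else 256.
        fcnet_hiddens = config.get("model", {}).get("fcnet_hiddens", [256, 256])
        return ([256] + list(fcnet_hiddens))[-1]
    # No extra heads: the custom model must output one Q-value per action.
    return num_actions


# DQN defaults (hiddens=[256], fcnet_hiddens=[256, 256]) on CartPole:
print(infer_dqn_num_outputs({"hiddens": [256], "model": {"fcnet_hiddens": [256, 256]}}, 2))  # 256
print(infer_dqn_num_outputs({"hiddens": [256], "model": {"fcnet_hiddens": [64]}}, 2))  # 64
print(infer_dqn_num_outputs({"hiddens": []}, 2))  # 2
```

Under this assumption, adding e.g. "fcnet_hiddens": [64] to the "model" dict of the trainer config should make your custom model receive num_outputs=64.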

Note that your custom model should actually use the observation (shape (4,)), which it currently doesn't: it outputs a constant tensor regardless of the incoming observation. But I assume this was just for demonstration purposes :)
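To make the diagram concrete, here is a plain-PyTorch sketch (not RLlib's own model class; CustomBody and DuelingHeads are hypothetical names) of a body for CartPole's (4,) observation that actually uses its input and emits num_outputs=256 features, followed by advantage and value heads built from hiddens=[256]. The heads are combined with the standard dueling aggregation Q = V + A - mean(A), which, to my understanding, is also what RLlib's DQN does when dueling=True.

```python
import torch
import torch.nn as nn


class CustomBody(nn.Module):
    """Hypothetical custom model body: obs (B, 4) -> features (B, num_outputs)."""

    def __init__(self, obs_dim: int = 4, num_outputs: int = 256):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(obs_dim, num_outputs), nn.ReLU())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)


class DuelingHeads(nn.Module):
    """Advantage and value MLPs stacked on top of the body's features."""

    def __init__(self, num_outputs: int = 256, hiddens=(256,), num_actions: int = 2):
        super().__init__()

        def mlp(out_dim: int) -> nn.Sequential:
            layers, ins = [], num_outputs
            for n in hiddens:
                layers += [nn.Linear(ins, n), nn.ReLU()]
                ins = n
            layers.append(nn.Linear(ins, out_dim))
            return nn.Sequential(*layers)

        self.advantage = mlp(num_actions)  # 256 -> 256 -> 2
        self.value = mlp(1)                # 256 -> 256 -> 1

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        a = self.advantage(features)
        v = self.value(features)
        # Dueling aggregation: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)
        return v + a - a.mean(dim=1, keepdim=True)


body, heads = CustomBody(), DuelingHeads()
obs = torch.randn(32, 4)      # batch of 32 CartPole observations
features = body(obs)          # (32, 256): the size num_outputs refers to
q_values = heads(features)    # (32, 2): one Q-value per action
print(features.shape, q_values.shape)
```

The point of the split is that the custom model never needs to know the number of actions; only the heads do.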


To add to @sven1977's answer: if you want full control over the model, I suggest extending DistributionalQTFModel from distributional_q_tf_model.py. Its __init__() shows how the additional layers are created on top of your model's output (model_out). For the torch framework, the counterpart is DQNTorchModel in dqn_torch_model.py.


Hey @sven1977,

Thank you for the answer! Your explanation makes it clear to me now. Indeed, the torch model I provided was just for demonstration.