[rllib] How to implement this model in RLlib?

Its final CNN layer outputs a 6x6 feature map rather than 1x1, and the built-in models need this to be 1x1.

Note that for the vision network case, you’ll probably have to configure conv_filters if your environment observations have custom sizes. For example, "model": {"dim": 42, "conv_filters": [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]} for 42x42 observations. Thereby, always make sure that the last Conv2D output has an output shape of [B, 1, 1, X] ([B, X, 1, 1] for PyTorch), where B=batch and X=last Conv2D layer’s number of filters, so that RLlib can flatten it. An informative error will be thrown if this is not the case.
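You can check a conv_filters setting by hand before launching training. Below is a minimal sketch in Python. It assumes square inputs and square kernels, and it assumes the built-in vision network uses "same" padding on every conv layer except the last and "valid" padding on the last. That is my reading of RLlib's implementation, so verify it against your version. The helper names `last_conv_output_shape` and `can_flatten` are hypothetical and are not RLlib APIs.

```python
import math


def last_conv_output_shape(dim, conv_filters):
    """Return (H, W, C) of the last conv output for a square dim x dim input.

    Assumption: 'same' padding on every layer except the last, and 'valid'
    padding on the last layer (mirrors RLlib's built-in vision net as I read it).
    Kernels are assumed square, so only the kernel height is used.
    """
    size, channels = dim, None
    last = len(conv_filters) - 1
    for i, (out_channels, (kernel, _), stride) in enumerate(conv_filters):
        if i < last:
            size = math.ceil(size / stride)        # 'same': ceil(in / stride)
        else:
            size = (size - kernel) // stride + 1   # 'valid': no padding
        channels = out_channels
    return size, size, channels


def can_flatten(shape):
    # Built-in models need the spatial part of the last conv output to be 1x1.
    return shape[:2] == (1, 1)


docs_filters = [[16, [4, 4], 2], [32, [4, 4], 2], [512, [11, 11], 1]]
print(last_conv_output_shape(42, docs_filters))  # (1, 1, 512): flattens fine
print(last_conv_output_shape(84, docs_filters))  # (11, 11, 512): needs other filters
```

The second call shows why custom observation sizes need their own filters. The same stack that collapses a 42x42 input to 1x1 leaves an 11x11 map for 84x84 input.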

@bug404 This doesn’t answer your question directly, but here is an example where you can define your own model in Keras: ray/custom_keras_model.py (master branch of ray-project/ray on GitHub).


Hey @bug404 , thanks for the question and the detailed image of your model. This could be really useful for other users that stumble on this question. Thanks @RickLan for the answer as well! :slight_smile:

Yeah, conv stacks need to be set up with care, as wrong settings lead to obscure shape errors inside the code. RLlib tries to catch these and output a more meaningful error.
The problem is that I don’t think RLlib has a built-in solution for a model like yours because of the flattening step (I admit, we should fix that).

I did figure out a proper conv setup that would point you in the right direction for the custom model that’s needed here:
[[32, [11, 11], 2], [32, [8, 8], 2], [32, [6, 6], 2], [32, [6, 6], 1]]
You need 4 Conv2D layers, each outputting 32 filters, with kernels of 11x11, 8x8, 6x6, and 6x6 and strides of 2, 2, 2, and 1. Your final output is then 32x6x6, which flattens to exactly 1152 nodes. After that, you add a 512-unit dense layer plus your action output.
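A layer-by-layer trace of the suggested custom model makes the 1152 figure easy to check. This sketch assumes an 84x84 input with 4 stacked frames and 6 actions. All three values are hypothetical, since the thread does not state them. It also assumes you keep RLlib's built-in padding convention: "same" on all conv layers but the last, and "valid" on the last. It only does the shape and parameter arithmetic and does not build an RLlib model. The function `trace_custom_model` is a hypothetical helper.

```python
import math


def trace_custom_model(dim, in_channels, conv_filters, hidden, num_actions):
    """Return (layer_name, output_shape, num_params) for each layer of a
    hypothetical custom model: conv stack -> flatten -> dense -> action logits.

    Assumes a square dim x dim input, 'same' padding on every conv except the
    last, and 'valid' padding on the last conv.
    """
    layers = []
    size, channels = dim, in_channels
    last = len(conv_filters) - 1
    for i, (out_ch, (kh, kw), stride) in enumerate(conv_filters):
        if i < last:
            size = math.ceil(size / stride)      # 'same' padding
        else:
            size = (size - kh) // stride + 1     # 'valid' padding
        params = kh * kw * channels * out_ch + out_ch  # weights + biases
        channels = out_ch
        layers.append((f"conv{i + 1}", (size, size, channels), params))
    flat = size * size * channels
    layers.append(("flatten", (flat,), 0))
    layers.append(("dense", (hidden,), flat * hidden + hidden))
    layers.append(("logits", (num_actions,), hidden * num_actions + num_actions))
    return layers


CONV_FILTERS = [[32, [11, 11], 2], [32, [8, 8], 2], [32, [6, 6], 2], [32, [6, 6], 1]]
layers = trace_custom_model(84, 4, CONV_FILTERS, 512, 6)
for name, shape, params in layers:
    print(f"{name:8s} {str(shape):14s} {params:>8d}")  # flatten row shows (1152,)
```

With these assumptions the spatial size goes 84, 42, 21, 11, 6. This is why the stack cannot feed RLlib's built-in flatten step, which expects 1x1, and why a custom model with an explicit flatten is needed.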

There should be an example of building an Atari-like model with RLlib’s built-in models somewhere in RLlib, since it involves various config settings such as conv_filters and hiddens, which new users find very hard to get started with. RLlib is a great Swiss Army knife, but the learning curve is too steep.
