LSTM/Attn wrappers take the policy output instead of the latent features

In the current implementation, it seems that if use_lstm (or use_attention) is true, the provided policy is wrapped with the LSTM or attention wrapper in an unexpected way: the wrapper appears to consume the policy's output rather than its latent features.

In pseudocode, the forward pass is (see here):

# LSTM wrapper's forward
def forward(inputs):
    # The wrapper first runs the wrapped policy, then feeds its output to the RNN.
    wrapped_out = wrapped.forward(inputs)
    return forward_rnn(wrapped_out)

Shouldn’t the recurrent module get the latent features instead of the output of the logits_branch? Or am I missing something and if the wrapper is used then the wrapped policy outputs the features? I.e. instead of doing this:

# Wrapped policy's forward
def forward(inputs):
    features = convs(inputs)
    return logits_branch(features)

Based on this paper and the Stable Baselines implementation, it seems that the latent state should be passed to the recurrent module (whichever one is being used).
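To make the question concrete, here is a toy sketch of the two wirings being compared. Everything in it (encoder, logits_branch, FEATURE_DIM, NUM_ACTIONS) is a hypothetical stand-in, not RLlib or Stable Baselines code; it only shows how wide the recurrent module's input would be in each case.

```python
# Toy stand-ins (hypothetical, not RLlib or Stable Baselines code).
FEATURE_DIM = 256  # assumed size of the encoder's latent vector
NUM_ACTIONS = 4    # assumed size of the action space


def encoder(obs):
    """Pretend CNN encoder: maps an observation to a latent feature vector."""
    return [float(sum(obs))] * FEATURE_DIM


def logits_branch(features):
    """Pretend policy head: maps latent features to one logit per action."""
    return [sum(features) / len(features)] * NUM_ACTIONS


def recurrent_on_logits(obs):
    """Wiring from the pseudocode above: the RNN would see the policy output."""
    rnn_input = logits_branch(encoder(obs))
    return len(rnn_input)


def recurrent_on_features(obs):
    """Wiring suggested in the question: the RNN sees the latent features."""
    rnn_input = encoder(obs)
    return len(rnn_input)


obs = [0.1, 0.2, 0.3]
print(recurrent_on_logits(obs))    # 4
print(recurrent_on_features(obs))  # 256
```

With these assumed sizes, the first wiring would squeeze everything the recurrent module sees through a 4-dimensional vector, which is the concern raised above.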

Hi @vakker00,

In the standard RLlib models there is really no difference between latent feature layers and logit (action) layers other than where they occur in the architecture and their sizes: latent layers sit in the middle, with sizes specified by fcnet_hiddens, and the logits layer sits at the end, with its size determined by the action space. When you request a wrapper the model catalog will create and use a final layer in the wrapper.
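As a rough illustration of that point, the sketch below lists the (input, output) sizes of a plain fully connected stack. The helper layer_sizes and the example numbers are assumptions made for illustration, not part of RLlib; only the name fcnet_hiddens comes from the model config.

```python
def layer_sizes(obs_dim, fcnet_hiddens, num_actions):
    """Return (in_size, out_size) for each dense layer in a simple stack.

    Hypothetical helper: hidden ("latent") layers come from fcnet_hiddens,
    and the final ("logits") layer is sized by the action space.
    """
    dims = [obs_dim] + list(fcnet_hiddens) + [num_actions]
    return list(zip(dims[:-1], dims[1:]))


# Assumed example: 8-dim observation, two hidden layers, 4 discrete actions.
print(layer_sizes(8, [256, 256], 4))
# [(8, 256), (256, 256), (256, 4)]
```

The last pair is the logits layer; structurally it is just another dense layer with a different output size.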

The wrapper methods were added after the original models were written. When they were added, they kept the original naming convention, which is where the confusion comes from.

@mannyv thanks for the response, there is no confusion about the “how”. The question is rather about the “why”, and the theoretical consideration of where the input to the recurrent wrapper should be extracted from.

RLlib’s default implementation takes the output of the policy and feeds it to the recurrent layers.
The referenced paper and the SB implementation take the hidden state of the policy (e.g. the output of a CNN encoder) rather than its output (which typically has significantly lower information content).

So the question is: is there a publication that the RLlib implementation of recurrent policies is based on?


My understanding is that SB3 does not have LSTM support. I just did a quick search to see whether that has changed and could not find any sign that it had. If I missed it, please do point it out to me.

SB2 has LSTM support but, as far as I can tell from looking at the code, the “default” method does it the same way as RLlib does. It depends on whether the user provides a net_arch parameter or not:

If the user does not, then it does it just like RLlib does here:

If the user does, then it is at the user's discretion:
for example, you can put a layer before the LSTM with net_arch=[64,'lstm',...] or leave it out with net_arch=['lstm',...]:
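A minimal sketch of how such a spec could be read, assuming a flat list with a single 'lstm' marker. The helper split_net_arch is hypothetical and is not the Stable Baselines parser; the concrete layer sizes are made up for the example.

```python
def split_net_arch(net_arch):
    """Split a layer spec at the 'lstm' marker.

    Returns (layers_before_lstm, layers_after_lstm).
    Hypothetical helper for illustration, not part of Stable Baselines.
    """
    if "lstm" not in net_arch:
        raise ValueError("net_arch must contain the 'lstm' marker")
    idx = net_arch.index("lstm")
    return net_arch[:idx], net_arch[idx + 1:]


# One dense layer before the LSTM, one after (sizes are assumptions).
print(split_net_arch([64, "lstm", 32]))  # ([64], [32])
# LSTM applied directly to the extracted features.
print(split_net_arch(["lstm", 32]))      # ([], [32])
```

In both cases, whatever comes after the marker is applied to the LSTM output, so where the marker sits decides what the recurrent layer sees.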

I am not a heavy SB user, so I may have gotten it wrong; please do correct any misunderstandings.

One more general comment: the forum has users with a very wide spectrum of RL experience, from very experienced to “oh man, they asked me to do an RL project yesterday, what is it?” With this in mind, I try to assume very little about someone's background when they ask a question, so that the answer might be more broadly useful.

@mannyv thanks for the reply. Yes, it’s only in SB2; the link in the original post points to that (also see the linked paper, which has a clearer illustration).

When you request a wrapper the model catalog will create and use a final layer in the wrapper.

To me it seems that the section you mentioned (with or without net_arch) clearly applies the RNN to the latent features BEFORE any of the policy or vf layers. I think this might also be the case for RLlib, but it’s slightly obfuscated.

I think what’s happening is that when the recurrent wrapper is used, the num_outputs used to instantiate the wrapped class is None, so self._logits is also None and self.last_layer_is_flattened is False. As a result, the wrapped class’ forward function takes the return conv_out, state branch instead of return logits, state (see here), i.e. it passes the latent vector to the wrapper and not the output of self._logits.
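A stripped-down sketch of that branching, as I understand it. ToyVisionNet and its _convs method are hypothetical stand-ins, not RLlib's VisionNetwork, and the last_layer_is_flattened handling is left out to keep it short.

```python
class ToyVisionNet:
    """Hypothetical stand-in for the wrapped model (not RLlib's VisionNetwork)."""

    def __init__(self, num_outputs=None):
        self.num_outputs = num_outputs
        # The logits layer is only built when an output size is requested.
        if num_outputs is None:
            self._logits = None
        else:
            self._logits = lambda feats: [0.0] * num_outputs  # dummy head

    def _convs(self, obs):
        # Pretend feature extractor: just casts the observation to floats.
        return [float(x) for x in obs]

    def forward(self, obs, state):
        conv_out = self._convs(obs)
        if self._logits is None:
            # Wrapper case: hand the latent features to the caller.
            return conv_out, state
        # Stand-alone case: return the action logits.
        return self._logits(conv_out), state


features, _ = ToyVisionNet(num_outputs=None).forward([1, 2, 3], [])
logits, _ = ToyVisionNet(num_outputs=2).forward([1, 2, 3], [])
print(features)  # [1.0, 2.0, 3.0]
print(logits)    # [0.0, 0.0]
```

So in this reading, a recurrent wrapper that builds the inner model with num_outputs=None receives the features, and the logits are produced by the wrapper's own final layer.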

So if that’s correct, then I think it actually clarifies my main question:

Or am I missing something and if the wrapper is used then the wrapped policy outputs the features?

Yes, when the wrapper is used then the wrapped policy passes the extracted features to the wrapper.

Edit: one potential bug is that the post_fcnet_hiddens layers are not applied when the wrapper is used (_logits_branch should be constructed similarly to the wrapped class’ _logits). That option was added recently to VisionNetwork (and I guess to others as well), and maybe the change was not carried over to the wrappers.


Looking at the cnn module in SB2: doesn’t it pass the CNN features to a linear layer as the last operation?