LSTM Auto Wrapper

Hi All,

Can someone clarify how the LSTM auto-wrapper and sequence length interact with the built-in PPO when combined with the curiosity module?

  1. If I have “max_seq_len”: 4, does that mean that up to 4 timesteps (obs) will be passed into the LSTM?
  2. What is the exact flow here? → Does it go obs → LSTM → basePPO? OR obs → basePPO → LSTM?

Thanks!

  1. max_seq_len controls truncated backpropagation through time (BPTT) for the LSTM. So the LSTM will get inputs of shape [B//4, 4, f_in_size].

I don't think it will affect the curiosity module, but I have not verified that.

  2. It is the second (obs → basePPO → LSTM). Take a look at the second diagram in this thread.

“So the lstm will get inputs of size [B//4,4,f_in_size]”

Wait, in this case what do “B” and “f_in_size” stand for?

B is the batch size and f_in_size is the size of the inputs coming from the previous layer in the model. Usually this is the same as lstm_cell_size but it need not be.
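To make the [B//4, 4, f_in_size] shape concrete, here is a minimal pure-Python sketch of folding a flat batch of timesteps into sequences. fold_time is a hypothetical helper for illustration, not an RLlib function, and it assumes the flat batch length is already a multiple of max_seq_len.

```python
def fold_time(flat_batch, max_seq_len):
    """Reshape a flat [B, f_in_size] batch into [B // max_seq_len, max_seq_len, f_in_size].

    Hypothetical helper for illustration only; assumes len(flat_batch) is a
    multiple of max_seq_len (i.e. the batch has already been padded).
    """
    if len(flat_batch) % max_seq_len != 0:
        raise ValueError("batch size must be a multiple of max_seq_len")
    return [
        flat_batch[start:start + max_seq_len]
        for start in range(0, len(flat_batch), max_seq_len)
    ]


# B = 8 timesteps, each with f_in_size = 3 features, max_seq_len = 4.
B, f_in_size, max_seq_len = 8, 3, 4
flat = [[float(t)] * f_in_size for t in range(B)]
folded = fold_time(flat, max_seq_len)
print(len(folded), len(folded[0]), len(folded[0][0]))  # 2 4 3
```

With max_seq_len = 4, B timesteps become B//4 sequences of 4 steps each; the feature size f_in_size is untouched.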

Sorry to revive the thread again, but I just wanted to clarify another point. If the LSTM is getting the outputs of the previous layer, split into slices of size input_obs/seq_len, then there is no time dimension? Like it's just a single LSTM layer of seq_len cells that feed the inputs into each other? Therefore it does not get previous timesteps at all (unless I opt in to providing the T-1 action and/or reward by setting those flags to True)?

So in the case I have:

fcnet_hiddens: [64],
use_lstm: True,
max_seq_len: 8,
lstm_cell_size: 20

and an observation of size 640 and an action space of size 5. Then the flow becomes the following:
640 → PPO, which returns 64 values → these get broken up into 8 pieces of 8 values each (piece1 → cell1 → (20-value output of the previous cell + piece2) → cell2 → repeat for 8 cells) → linear mapping to the action space of 5?

Is that correct?

Hi @Denys_Ashikhin,

That is not correct.

Here is a way to think about it. We have to consider two different ways it is used.

  1. During rollouts (calling compute_actions…):

In this mode RLlib interacts with the environments one step at a time. On each step it gets the observation from the environment and provides the policy with that observation (in your example, what reaches the LSTM is the size-64 embedding produced by the fcnet) plus the state of the LSTM at the previous timestep, which in your example has size 20 (the lstm_cell_size). The policy uses its model to compute the next set of actions to take, and also the next state of the LSTM, having seen the old state plus the new observation.

RLlib stores that new_state value and, on the next step of the environment, forwards it to the next call to compute_actions. In essence, it chains the state from step to step. The policy model has a method called get_initial_state that provides the state to use on the first step, since there is no previous one. In RLlib the default is all zeros.

In addition to chaining the state, RLlib also saves that state in the sample batch to use during training.
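The rollout-time chaining can be sketched in a few lines of pure Python. This is not RLlib code: toy_policy_step is a hypothetical stand-in for compute_actions that uses a plain tanh recurrence with made-up weights instead of a real LSTM, and it keeps a single state vector, whereas a real LSTM state is a pair (h, c). get_initial_state here is a standalone function that only mirrors the name of the model method.

```python
import math

CELL_SIZE = 3  # stands in for lstm_cell_size (20 in the example above)


def get_initial_state():
    # Like RLlib's default, start from an all-zeros state.
    return [0.0] * CELL_SIZE


def toy_policy_step(obs, state):
    """Hypothetical stand-in for compute_actions: a tanh recurrence, not a real LSTM."""
    new_state = [math.tanh(0.5 * s + 0.1 * obs + 0.01 * i) for i, s in enumerate(state)]
    action = 1 if sum(new_state) > 0 else 0
    return action, new_state


def rollout(observations):
    state = get_initial_state()
    sample_batch = {"obs": [], "actions": [], "state_in": [], "state_out": []}
    for obs in observations:  # one environment step at a time
        action, new_state = toy_policy_step(obs, state)
        # Save the state the policy saw, so it is available at training time.
        sample_batch["obs"].append(obs)
        sample_batch["actions"].append(action)
        sample_batch["state_in"].append(state)
        sample_batch["state_out"].append(new_state)
        state = new_state  # chain the state into the next step
    return sample_batch


batch = rollout([0.2, -1.0, 0.7, 1.5])
print(batch["state_in"][0])  # [0.0, 0.0, 0.0]
```

The key property is that the state fed in at step t+1 is exactly the state produced at step t, and every state_in is recorded in the batch.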

  2. During training (in the loss function):
    In this phase you have a sequence of transitions, possibly from many episodes.

The first thing RLlib does is pad the episodes so that their lengths are multiples of max_seq_len. It chops each episode into subsequences of at most max_seq_len steps, then pads any subsequence shorter than max_seq_len with zeros. For example, with max_seq_len = 20, an episode with 30 timesteps is broken up into 2 sequences: one with the first 20 timesteps and another with the remaining 10 timesteps plus 10 zeros.
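A minimal sketch of the chop-and-pad step for a single episode, using max_seq_len = 20 as in the example above. chop_and_pad is a hypothetical helper, observations are plain floats here for simplicity (real observations are vectors), and each episode is assumed to be chopped on its own so no sequence mixes steps from two episodes.

```python
def chop_and_pad(episode, max_seq_len, pad_value=0.0):
    """Split one episode into chunks of at most max_seq_len steps, zero-padding the last.

    Hypothetical helper; returns (padded_sequences, seq_lens).
    """
    padded, seq_lens = [], []
    for start in range(0, len(episode), max_seq_len):
        chunk = episode[start:start + max_seq_len]
        seq_lens.append(len(chunk))  # real (unpadded) length, needed later for masking
        padded.append(chunk + [pad_value] * (max_seq_len - len(chunk)))
    return padded, seq_lens


episode = [float(t + 1) for t in range(30)]  # 30 timesteps, observation = t + 1
sequences, seq_lens = chop_and_pad(episode, max_seq_len=20)
print(seq_lens)            # [20, 10]
print(sequences[1][-10:])  # ten padding zeros
```

Keeping seq_lens alongside the padded data is what later lets the loss tell real timesteps from padding.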

[(o0-0,s0-0), (o0-1,s0-1), …, (o0-n,s0-n), (o2-0,s2-0), (o2-1,s2-1), …, (o2-n,s2-n), …]

After this, RLlib passes all of those observations through the fcnet. Time does not matter here, since the embedding network is applied to each timestep independently.

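The outputs of the fcnet are then reshaped to [batch, max_seq_len, emb_size], where batch is now the number of sequences.
For your example this is [Batch_size//20 x 20 x 64] (using max_seq_len = 20 from the padding example; with your max_seq_len of 8 it would be [Batch_size//8 x 8 x 64]).

The LSTM then runs over the time axis of each sequence, starting from the state that was saved during rollouts for that sequence's first timestep. Its outputs are reshaped back to [batch*max_seq_len, lstm_cell_size] and passed through the rest of the network and the action distribution.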
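The training-time forward pass can be sketched with pure-Python lists to show only the shapes: per-timestep embedding, fold into [num_seqs, max_seq_len, emb_size], unroll over time from one saved state per sequence, then flatten back. toy_fcnet and toy_recurrent_step are hypothetical stand-ins (the recurrence is a tanh update, not a real LSTM), and the small sizes stand in for 640, 64 and 20.

```python
import math

OBS_SIZE, EMB_SIZE, CELL_SIZE = 6, 4, 3  # stand-ins for 640, 64, 20
MAX_SEQ_LEN, NUM_SEQS = 4, 2


def toy_fcnet(obs):
    # Applied to each timestep independently: [OBS_SIZE] -> [EMB_SIZE].
    total = sum(obs)
    return [math.tanh(0.1 * (i + 1) * total) for i in range(EMB_SIZE)]


def toy_recurrent_step(x, state):
    # Hypothetical stand-in for one LSTM step: [EMB_SIZE] + [CELL_SIZE] -> [CELL_SIZE].
    drive = 0.1 * sum(x)
    return [math.tanh(0.5 * s + drive) for s in state]


def forward(flat_obs, state_in, max_seq_len):
    # 1) fcnet on every timestep: [num_seqs*T, OBS_SIZE] -> [num_seqs*T, EMB_SIZE]
    emb = [toy_fcnet(o) for o in flat_obs]
    # 2) fold time: -> [num_seqs, T, EMB_SIZE]
    seqs = [emb[i:i + max_seq_len] for i in range(0, len(emb), max_seq_len)]
    # 3) unroll over time, starting each sequence from its own saved state
    outputs = []
    for seq, state in zip(seqs, state_in):
        seq_out = []
        for x in seq:
            state = toy_recurrent_step(x, state)
            seq_out.append(state)
        outputs.append(seq_out)
    # 4) unfold: [num_seqs, T, CELL_SIZE] -> [num_seqs*T, CELL_SIZE] for the action head
    return [step for seq_out in outputs for step in seq_out]


flat_obs = [[0.1 * t] * OBS_SIZE for t in range(NUM_SEQS * MAX_SEQ_LEN)]
state_in = [[0.0] * CELL_SIZE for _ in range(NUM_SEQS)]  # one saved state per sequence
out = forward(flat_obs, state_in, MAX_SEQ_LEN)
print(len(out), len(out[0]))  # 8 3
```

Note that gradients would only flow through the max_seq_len steps inside each sequence; anything earlier reaches the LSTM only through the saved state_in.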

One last detail is that in the loss function the timesteps that were padded are zeroed out so that we do not have dummy observations contributing to the loss.
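The masking idea can be shown with a small helper that averages per-timestep losses only over real steps, using the sequence lengths recorded when padding. masked_mean is a hypothetical pure-Python helper in the spirit of the mask RLlib builds from sequence lengths, not RLlib's actual implementation.

```python
def masked_mean(per_step_loss, seq_lens, max_seq_len):
    """Average per-timestep losses, ignoring padded steps.

    per_step_loss is flat [num_seqs * max_seq_len]; seq_lens holds each sequence's real length.
    """
    mask = [
        1.0 if t < length else 0.0
        for length in seq_lens
        for t in range(max_seq_len)
    ]
    total = sum(loss * m for loss, m in zip(per_step_loss, mask))
    return total / sum(mask)


# Two sequences of max_seq_len = 4; the second has only 2 real steps.
losses = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 100.0, 100.0]
print(masked_mean(losses, seq_lens=[4, 2], max_seq_len=4))  # 3.5
```

The two padded steps (loss 100.0) have no effect: the result is (1 + 2 + 3 + 4 + 5 + 6) / 6.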

Hopefully this was not too confusing.


Wow, thanks for that massive write-up. Okay, I think I understood the rollout phase pretty well, and I understood the training part OK too. My only question is why it's done this way.

From my basic understanding of LSTMs: during rollouts it only gets the previous cell output (and of course the embedding from the FCNET), which holds whatever it has deemed the most important information to pass on to the next observation/cell for computing the action, so it is really only looking one timestep back directly. However, given how LSTMs work, for all we know each cell might carry values from the 3rd timestep all the way to the end, so in that sense we have an LSTM that spans the length of our entire episode. So if we have 100 steps in an episode, it's kind of like we had an LSTM of length 100 (along with the FCNET).

Furthermore, even if my FCNET is 64, I can specify a custom “output_size” of, say, 1024, in which case it will take my 640 obs → 64 → 1024, and that plus the 20-value LSTM state (the cell size) is combined in an LSTM cell to produce an action.
Let me know if my understanding of point 1 is correct now.

For the training phase, I'll be honest: the reason it's done this way went over my head. I would really, really appreciate it if you could link some articles/papers explaining the rationale for this training setup.

Moreover, I would also REALLY appreciate it if you could give me some rough tips on how sequence length affects final results and when you would use which lengths. Same for cell size. If there are any papers/articles that look into the more practical applications/results of that, I would love to read them as well.

Once more, thank you for all your help so far!