That is not correct.
Here is one way to think about it. There are two different phases in which the LSTM is used:
- During rollouts (calling compute_actions…):
In this mode RLlib interacts with the environments one step at a time. On each step it gets the observation from the environment and passes the policy that observation (in your example, the size-64 embedding produced by the fcnet) plus the state of the LSTM from the previous timestep. In your example that state has size 20, the lstm_cell_size. The policy uses its model to compute the next action to take and also the next LSTM state, having seen the old state plus the new observation.
RLlib stores that new state and, on the next environment step, passes it along to the next call to compute_actions. In essence, it chains the state from step to step. Because there is no previous state on the first step, the policy model has a method called get_initial_state that provides the starting state. In RLlib the default is all zeros.
In addition to chaining the state, RLlib also saves that state in the sample batch to use during training.
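The rollout-side chaining can be sketched as a plain loop. This is a toy under stated assumptions, not RLlib's code: `toy_compute_actions` and `toy_get_initial_state` are hypothetical stand-ins, and a single tanh RNN cell replaces the LSTM (a real LSTM carries two state tensors, h and c). The bookkeeping is the part that matters: start from zeros, feed the previous state in, keep the new one for the next step, and record the state that was fed in so it can be reused at training time.

```python
import numpy as np

LSTM_CELL_SIZE = 20  # assumption: state size from the example
EMB_SIZE = 64        # assumption: size of the fcnet embedding

rng = np.random.default_rng(0)
# Hypothetical toy recurrent cell (tanh RNN), standing in for the LSTM.
W_in = rng.normal(scale=0.1, size=(EMB_SIZE, LSTM_CELL_SIZE))
W_h = rng.normal(scale=0.1, size=(LSTM_CELL_SIZE, LSTM_CELL_SIZE))


def toy_get_initial_state():
    # Hypothetical: mirrors the all-zeros default starting state.
    return np.zeros(LSTM_CELL_SIZE)


def toy_compute_actions(obs_emb, state):
    """Hypothetical stand-in for compute_actions: returns (action, new_state)."""
    new_state = np.tanh(obs_emb @ W_in + state @ W_h)
    action = int(new_state.sum() > 0)
    return action, new_state


def rollout(num_steps):
    state = toy_get_initial_state()
    batch = []  # (obs_emb, state_in, action), like the saved sample batch
    for _ in range(num_steps):
        obs_emb = rng.normal(size=EMB_SIZE)  # fake observation embedding
        action, new_state = toy_compute_actions(obs_emb, state)
        batch.append((obs_emb, state, action))  # store the state that was fed in
        state = new_state                       # chain it into the next step
    return batch


batch = rollout(5)
```

Each stored `state_in` is exactly the previous step's output state, which is what lets training resume the LSTM from the right place.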
- During training (in the loss function):
In this phase you have a sequence of transitions from possibly many episodes.
The first thing RLlib does is pad all of the episodes so that each one's length is a multiple of max_seq_len. It chops each episode into subsequences of at most max_seq_len steps, then zero-pads any subsequence that is shorter than max_seq_len. For example, with max_seq_len = 20, an episode with 30 timesteps will be broken up into 2 sequences: one with the first 20 timesteps and another with the remaining 10 timesteps plus 10 zero-padded steps.
[(o0-0,s0-0), (o0-1,s0-1), …, (o0-n,s0-n), (o2-0,s2-0), (o2-1,s2-1), …, (o2-n,s2-n), …]
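A minimal sketch of the chop-and-pad step as a standalone NumPy function. `chop_and_pad` is a hypothetical helper, not RLlib's implementation. Besides the padded sequences it returns the true length of each one, which is what the loss later needs in order to ignore the padding.

```python
import numpy as np


def chop_and_pad(episode, max_seq_len):
    """Split one episode (array of shape [T, ...]) into chunks of at most
    max_seq_len steps, zero-padding the last chunk.
    Returns (padded [num_seqs, max_seq_len, ...], seq_lens [num_seqs])."""
    chunks, seq_lens = [], []
    for start in range(0, len(episode), max_seq_len):
        chunk = episode[start:start + max_seq_len]
        seq_lens.append(len(chunk))
        pad_shape = (max_seq_len - len(chunk),) + episode.shape[1:]
        pad = np.zeros(pad_shape, dtype=episode.dtype)
        chunks.append(np.concatenate([chunk, pad], axis=0))
    return np.stack(chunks), np.array(seq_lens)


# 30 timesteps with a single feature each (values 1..30 so padding is visible).
episode = np.arange(1, 31, dtype=np.float32).reshape(30, 1)
padded, seq_lens = chop_and_pad(episode, max_seq_len=20)
print(padded.shape, seq_lens)  # (2, 20, 1) [20 10]
```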
After this, RLlib passes all of those observations through the fcnet. Here time does not matter, since the embedding network is not time dependent.
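The outputs of the fcnet are then reshaped to [batch, max_seq_len, emb_size], where batch is the number of sequences.
For your example this is [batch_size // 20, 20, 64]. This tensor is fed to the LSTM, which starts each sequence from the state that was saved in the sample batch for that sequence's first timestep.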
The outputs of the LSTM are then reshaped back to [batch * max_seq_len, lstm_cell_size] and passed through the rest of the network and the action distribution.
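The reshape, unroll, reshape-back sequence can be sketched in NumPy. The sizes and random weights are assumptions for illustration, `lstm_unroll` is a hand-written standard LSTM rather than RLlib's model, and the zero initial states stand in for the per-sequence states saved in the sample batch.

```python
import numpy as np

# Assumptions: 3 sequences, max_seq_len 20, embedding 64, cell size 20.
B, T, EMB, CELL = 3, 20, 64, 20
rng = np.random.default_rng(0)


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


# Hypothetical LSTM weights: one matrix producing all four gates (i, f, g, o).
W = rng.normal(scale=0.1, size=(EMB + CELL, 4 * CELL))
b = np.zeros(4 * CELL)


def lstm_unroll(x, h0, c0):
    """x: [B, T, EMB]; h0, c0: [B, CELL] initial state per sequence.
    Returns hidden outputs of shape [B, T, CELL]."""
    h, c = h0, c0
    outs = []
    for t in range(x.shape[1]):
        z = np.concatenate([x[:, t], h], axis=1) @ W + b
        i, f, g, o = np.split(z, 4, axis=1)
        c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
        h = sigmoid(o) * np.tanh(c)
        outs.append(h)
    return np.stack(outs, axis=1)


flat_emb = rng.normal(size=(B * T, EMB))       # fcnet output: [B*T, emb_size]
seq_emb = flat_emb.reshape(B, T, EMB)          # [batch, max_seq_len, emb_size]
h0 = np.zeros((B, CELL))
c0 = np.zeros((B, CELL))
lstm_out = lstm_unroll(seq_emb, h0, c0)        # [batch, max_seq_len, cell]
flat_out = lstm_out.reshape(B * T, CELL)       # [batch*max_seq_len, cell]
```

The reshape only works because the flat batch is ordered sequence by sequence, with every sequence padded to exactly max_seq_len rows, so row `k * T + t` of the flat tensor is timestep `t` of sequence `k`.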
One last detail: in the loss function, the padded timesteps are masked out (zeroed out) so that the dummy observations do not contribute to the loss.
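The masking can be sketched as building a boolean mask from the true sequence lengths and averaging the per-timestep loss only over real steps. The helper names are hypothetical; RLlib's built-in losses do something comparable using the batch's sequence lengths, but this is not their code.

```python
import numpy as np


def sequence_mask(seq_lens, max_seq_len):
    # True for real timesteps, False for padding: shape [num_seqs, max_seq_len].
    return np.arange(max_seq_len)[None, :] < np.asarray(seq_lens)[:, None]


def masked_mean(per_step_loss, seq_lens):
    """per_step_loss: flat array of shape [num_seqs * max_seq_len]."""
    max_seq_len = per_step_loss.size // len(seq_lens)
    mask = sequence_mask(seq_lens, max_seq_len).reshape(-1)
    return per_step_loss[mask].mean()


loss = np.ones(40)   # 2 sequences x max_seq_len 20, flattened
loss[30:] = 100.0    # large values on the 10 padded steps of sequence 2
print(masked_mean(loss, [20, 10]))  # 1.0
```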
Hopefully this was not too confusing.