[rllib] SampleBatch "state_in_0" dimension shorter than expected

What is your max_seq_len? Before the data is passed into forward_rnn, two things happen. First, the data in the sample batch is padded so that all inputs are the same length as your longest episode. So if you had sequences of length [5, 12, 10, 4], they would all be padded with zeros to be 12 steps long, and you would end up with a total of 12 * 4 = 48 timesteps in your sample batch.
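
To make the padding concrete, here is a minimal sketch of that step (my own illustration, not RLlib's actual implementation), assuming one feature per timestep:

```python
import numpy as np

# Illustrative sketch (not RLlib's internal code): zero-pad variable-length
# episodes so every sequence matches the longest one in the batch.
def pad_episodes(episodes, feature_dim=1):
    max_len = max(len(ep) for ep in episodes)
    padded = np.zeros((len(episodes), max_len, feature_dim), dtype=np.float32)
    for i, ep in enumerate(episodes):
        padded[i, :len(ep)] = ep
    return padded

# Episodes of length [5, 12, 10, 4] all become 12 steps long,
# giving 4 * 12 = 48 total timesteps in the padded batch.
episodes = [np.ones((n, 1)) for n in (5, 12, 10, 4)]
print(pad_episodes(episodes).shape)  # (4, 12, 1)
```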

Second, the sequences are chopped so that none is longer than max_seq_len. Say your max_seq_len in the model config was 20. Your state_in would then be of size [4, cell_size], since from the RNN's perspective none of the sequences needs truncated backpropagation through time. This is likely what you are seeing: when you are using an RNN you only get the initial state of each sequence, and the other states are generated internally by the RNN logic. If your max_seq_len were 5 in the example above, you would likely have seq_lens of [5, 5, 5, 2, 5, 5, 4], a sample batch padded to 7 * 5 = 35 timesteps, and a state_in with a shape of [7, cell_size].
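
Here is a rough sketch of that chopping logic (again my own illustration; RLlib's internal sequencing code is more involved):

```python
# Illustrative sketch (not RLlib's internal code): chop episode lengths into
# chunks of at most max_seq_len. Each chunk becomes its own sequence, and
# state_in holds one initial state per chunk.
def chop_into_sequences(episode_lens, max_seq_len):
    seq_lens = []
    for ep_len in episode_lens:
        while ep_len > 0:
            seq_lens.append(min(ep_len, max_seq_len))
            ep_len -= max_seq_len
    return seq_lens

seq_lens = chop_into_sequences([5, 12, 10, 4], max_seq_len=5)
print(seq_lens)       # [5, 5, 5, 2, 5, 5, 4]
print(len(seq_lens))  # 7 -> state_in has shape [7, cell_size]
```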

The other thing to keep in mind is that there are several passes through the model with dummy data before training starts, to calculate view requirements and other values needed by compute_actions and the loss functions. I have found that sometimes those passes have shapes that I never see during actual training.
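
If you want to see when those dummy passes happen, one option is to print the shapes inside your model's forward_rnn. Below is a rough sketch of a minimal custom recurrent model with such a debug print, following RLlib's ModelV2/RecurrentNetwork torch API; the cell_size of 64 and the single-LSTM structure are just placeholders, not anything from your setup:

```python
import torch
import torch.nn as nn
from ray.rllib.models.modelv2 import ModelV2
from ray.rllib.models.torch.recurrent_net import RecurrentNetwork
from ray.rllib.utils.annotations import override

class DebugRNN(RecurrentNetwork, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        nn.Module.__init__(self)
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        self.cell_size = 64  # arbitrary choice for this sketch
        # Assumes a flat Box observation space.
        self.lstm = nn.LSTM(obs_space.shape[0], self.cell_size, batch_first=True)
        self.logits = nn.Linear(self.cell_size, num_outputs)
        self.value_branch = nn.Linear(self.cell_size, 1)
        self._features = None

    @override(ModelV2)
    def get_initial_state(self):
        return [torch.zeros(self.cell_size), torch.zeros(self.cell_size)]

    @override(RecurrentNetwork)
    def forward_rnn(self, inputs, state, seq_lens):
        # Shapes printed here during the initial dummy passes (view-requirement
        # inference, loss initialization) can differ from training shapes.
        print("inputs:", inputs.shape,
              "state_in:", [s.shape for s in state],
              "seq_lens:", seq_lens)
        h, c = state[0].unsqueeze(0), state[1].unsqueeze(0)
        self._features, (h, c) = self.lstm(inputs, (h, c))
        return self.logits(self._features), [h.squeeze(0), c.squeeze(0)]

    @override(ModelV2)
    def value_function(self):
        return torch.reshape(self.value_branch(self._features), [-1])
```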
