For something like the GTrXL implementation, how does the model receive input sequences (histories of inputs from previous time steps) to attend over? Are they already provided as part of the input dict, or does GTrXL save past inputs to re-use for Multi-Headed Attention?
My environment produces observations of Box(123,), yet the input dict receives observations of shape (32, 123). Is 32 the batch size or the sequence size? I set the max_seq_len of my model to 50, so I'm not sure where the 32 is coming from, which is why I am confused about how RLlib passes sequences/histories to the model.
Hey @jarbus. Great question. The 32 is just a dummy batch that we pass through the loss to check whether everything compiles correctly and to determine which fields of the batch your model/loss/exploration/etc. accesses.
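In other words, the batch you are seeing during that check looks roughly like this (plain NumPy illustration, not actual RLlib code; the 32 is just an assumed test-batch size and has nothing to do with max_seq_len):

```python
import numpy as np

obs_dim = 123            # your Box(123,) observation space
dummy_batch_size = 32    # size of the dummy test batch (assumption, not a real RLlib variable)

# The dummy batch is simply (batch_size, obs_dim) -- no time/sequence axis yet.
dummy_obs = np.zeros((dummy_batch_size, obs_dim), dtype=np.float32)
print(dummy_obs.shape)   # -> (32, 123)
```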
During the actual sampling (inference) and learning, RLlib does the following:
Inference: RLlib will provide the GTrXL model with an internal state input of shape [B, T, D] x attention_num_transformer_units, where B = vectorized env size (`num_envs_per_worker`), T = `attention_memory_inference` (here 100), and D = `attention_dim` (the size of each memory vector).
Note that max_seq_len does not play a role for inference (the “seq-len” here is always just 1).
Training: the state input has shape [B, T, D] x attention_num_transformer_units, where B is determined by `train_batch_size`/`sgd_minibatch_size`, T = `attention_memory_training` (here 50), and D = `attention_dim`.
Note that max_seq_len here determines the length of the "repeat" chunks inside the train batch. So e.g. if max_seq_len=10 and attention_memory_training=50, you would see chunks of length 10 in your train batch, but each chunk would still get a memory input of shape [B, T=50, …]. Whether such a setup would make sense is a different question.
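To make the shapes above concrete, here is a small sketch of what the state tensors look like (all names and values are illustrative, not actual RLlib internals):

```python
import numpy as np

# Illustrative values only; adjust to your own config.
attention_num_transformer_units = 1
attention_dim = 64                  # size of each memory vector
attention_memory_inference = 100
attention_memory_training = 50

num_envs_per_worker = 2             # B during inference
num_chunks_in_train_batch = 16      # B during training (train batch split into max_seq_len chunks)

# Inference: one memory tensor per transformer unit; the new input itself has seq-len 1.
inference_state = [
    np.zeros((num_envs_per_worker, attention_memory_inference, attention_dim))
    for _ in range(attention_num_transformer_units)
]

# Training: same structure, but the memory length is attention_memory_training.
training_state = [
    np.zeros((num_chunks_in_train_batch, attention_memory_training, attention_dim))
    for _ in range(attention_num_transformer_units)
]

print([s.shape for s in inference_state])  # [(2, 100, 64)]
print([s.shape for s in training_state])   # [(16, 50, 64)]
```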
I would always recommend setting your values to: max_seq_len=attention_memory_inference=attention_memory_training=[somewhere between 20 and 200].
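In config terms that would look roughly like this (the dim/head values are just placeholders; the point is keeping the three marked keys aligned):

```python
config = {
    "env": "CartPole-v1",            # placeholder env
    "num_envs_per_worker": 1,
    "model": {
        "use_attention": True,
        "attention_num_transformer_units": 1,
        "attention_dim": 64,
        "attention_num_heads": 2,
        "attention_head_dim": 32,
        "attention_position_wise_mlp_dim": 64,
        # Keep these three values aligned, as recommended above:
        "max_seq_len": 50,
        "attention_memory_inference": 50,
        "attention_memory_training": 50,
    },
}
# Then pass it to your trainer as usual, e.g. ray.tune.run("PPO", config=config).
```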
Hi!
I'm pretty much new to Ray and RL and I am wondering about the influence of max_seq_len in the GTrXL implementation. My understanding was that Transformers in general take a sequence of tokens as input, over which attention is performed, and so far I thought that max_seq_len represented that input sequence length. But if I understand the answer correctly, during inference GTrXL actually takes as input only the current observation (plus a set of memory tensors). So the sequence length is just one, and then the memory is added on top. Is that correct?