How does RLlib handle sequences?

For something like the GTrXL implementation, how does the model receive input sequences (histories of input at previous times) to attend over? Are they already provided as part of the input dict, or does GTrXL save past inputs to re-use for Multi-Headed Attention?

My environment produces observations of Box(123,), yet the input dict receives observations of shape (32, 123). Is this the batch size or the sequence size? I set my model’s max_seq_len to 50, so I’m not sure where the 32 is coming from, which is why I’m confused about how RLlib passes sequences/histories to the model.


Hey @jarbus. Great question. The 32 is just a dummy batch that we pass through the loss to check whether everything compiles correctly and to determine which fields of the batch your model/loss/exploration/etc. accesses.
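In other words, the 32 is a batch dimension, not a sequence dimension. A minimal sketch of what the model sees during that initial dummy pass (the batch size of 32 and the obs shape of 123 are taken from the question above; everything else is illustrative):

```python
import numpy as np

# During its initial test pass, RLlib feeds a small dummy batch of
# fake observations through the model/loss. The leading 32 is that
# dummy batch size B, not max_seq_len.
obs_space_shape = (123,)   # Box(123,) from the environment
dummy_batch_size = 32      # dummy batch size observed in input_dict

dummy_obs = np.zeros((dummy_batch_size,) + obs_space_shape)
assert dummy_obs.shape == (32, 123)  # what input_dict["obs"] looks like
```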

During the actual sampling (inference) and learning, RLlib does the following:

  1. Inference: RLlib provides the GTrXL model with an internal state input of shape [B, `attention_memory_inference`=100, attention_dim] x attention_num_transformer_units, where B is the vectorized env size (`num_envs_per_worker`) and attention_dim is the size of each memory vector.
    Note that max_seq_len plays no role during inference (the “seq-len” there is always just 1).

  2. Training: the state input has shape [B (`train_batch_size`/`sgd_minibatch_size`), `attention_memory_training`=50, attention_dim] x attention_num_transformer_units.
    Note that max_seq_len here determines the “repeat” (chunk length) inside the train batch. So e.g. with max_seq_len=10 and attention_memory_training=50, you would see chunks of length 10 in your train batch, but each chunk would still have a memory input of shape [B, T=50, …]. Whether such a setup makes sense is a different question.
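To make the two shape layouts concrete, here is a small NumPy sketch. The values for B and attention_dim are assumptions picked for illustration; the memory lengths are the defaults mentioned above:

```python
import numpy as np

# Illustrative values (B and attention_dim are assumptions, not defaults):
num_envs_per_worker = 8              # B during inference
train_batch_B = 16                   # B during training (number of chunks)
attention_num_transformer_units = 1
attention_memory_inference = 100     # memory T at inference
attention_memory_training = 50       # memory T at training
attention_dim = 64                   # size of each memory vector

# 1) Inference: one memory tensor per transformer unit; the actual
#    input "seq-len" per step is always 1.
inference_state = [
    np.zeros((num_envs_per_worker, attention_memory_inference, attention_dim))
    for _ in range(attention_num_transformer_units)
]
assert inference_state[0].shape == (8, 100, 64)

# 2) Training: same layout, but B counts the max_seq_len-sized chunks
#    and T is attention_memory_training.
training_state = [
    np.zeros((train_batch_B, attention_memory_training, attention_dim))
    for _ in range(attention_num_transformer_units)
]
assert training_state[0].shape == (16, 50, 64)
```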

I would recommend always setting these three values equal:
max_seq_len = attention_memory_inference = attention_memory_training = [somewhere between 20 and 200].
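As a config fragment, that recommendation looks like this (keys are RLlib’s attention-net model settings; the specific values of 50 and 64 are just one reasonable choice, not requirements):

```python
# Goes under config["model"] of your algorithm config.
model_config = {
    "use_attention": True,
    # Recommendation: keep these three equal, somewhere in [20, 200].
    "max_seq_len": 50,
    "attention_memory_inference": 50,
    "attention_memory_training": 50,
    # Illustrative architecture choices (assumptions, tune for your task):
    "attention_num_transformer_units": 1,
    "attention_dim": 64,
}
```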
