Valid inputs for `state`, `seq_lens` in GTrXLNet

  • High: It blocks me from completing my task.

I’m trying to modify the example GTrXLNet to use an action mask, but my first step is just to get the example model working. The example shows how to use the model as part of a training run, but not on its own.

I’m using the code below and getting the error message that follows it. As best I can tell, the issue is in my values for `state` and/or `seq_lens`.

`model.get_initial_state()` returns an empty list. That looks like it should be an error, but the Ray example uses this class without problems, so I don’t understand how an empty initial state can work.

What am I misunderstanding about using this class?

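import gymnasium as gym
from numpy import array
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.models.torch.attention_net import GTrXLNet
from torch import tensor

env2 = RandomEnv(
    config={
        "observation_space": gym.spaces.Box(0, 1, shape=[4]),
        "action_space": gym.spaces.MultiDiscrete([2]),
    }
)
# %% Ray example mode
model = GTrXLNet(
    env2.observation_space,
    env2.action_space,
    num_outputs=2,  # assumed value; this argument was lost from the post
    model_config={"max_seq_len": 10},
    name="gtrxl",  # assumed value; this argument was lost from the post
)

[logits, state] = model.forward(
    input_dict={"obs": tensor(env2.observation_space.sample())},
    state=model.get_initial_state(),  # an empty list here
    seq_lens=array([1]),  # assumed value; this argument was lost from the post
)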


Traceback (most recent call last):
  File "/home/user/clockpunch/tests/nets/", line 59, in <module>
    [logits, state] = model.forward(
  File "/home/user/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/torch/", line 223, in forward
    all_out = self.layers[i](all_out, memory=state[i // 2])
IndexError: list index out of range

Here is an excerpt from GTrXLNet near the error, with the values of the relevant variables:

all_out = observations
memory_outs = []
for i in range(len(self.layers)):
    # MHA layers which need memory passed in.
    if i % 2 == 1:
        all_out = self.layers[i](all_out, memory=state[i // 2])  # error here
    # All other layers take no memory; their output is kept in memory_outs.
    else:
        all_out = self.layers[i](all_out)
        memory_outs.append(all_out)

# i = 1
# all_out.shape = torch.Size([1, 64])
# state = []
# len(memory_outs) = 1
# memory_outs[0].shape = torch.Size([1, 64])
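To convince myself of what the loop is doing, here is a toy reproduction of just the indexing, with no Ray or torch involved. It assumes the layer list is one memory-free input layer followed by an attention layer and a memory-free layer per transformer unit, which is my reading of the excerpt rather than something I have confirmed. The function names and the string placeholders for memory are made up for illustration.

```python
# Toy stand-in for the layer loop in the excerpt; no Ray or torch needed.
# Assumption: one memory-free input layer, then each transformer unit adds
# one attention layer (odd index, needs memory) and one memory-free layer.


def run_layers(num_transformer_units, state):
    """Return the memory entries the odd-indexed layers would read."""
    num_layers = 1 + 2 * num_transformer_units
    used_memory = []
    for i in range(num_layers):
        if i % 2 == 1:
            # Same indexing as the excerpt: layer i reads state[i // 2].
            used_memory.append(state[i // 2])
    return used_memory


def try_run(num_transformer_units, state):
    """Run the loop and report an IndexError as a string instead of raising."""
    try:
        return run_layers(num_transformer_units, state)
    except IndexError as err:
        return f"IndexError: {err}"


empty_state_result = try_run(1, [])  # mirrors state = [] from the traceback
one_entry_result = try_run(1, ["memory_0"])
two_unit_result = try_run(2, ["memory_0", "memory_1"])

print(empty_state_result)
print(one_entry_result)
print(two_unit_result)
```

Under that assumption, `state` needs one entry per transformer unit, so an empty list fails at the first attention layer (`i = 1`), which matches the traceback.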


It appears that GTrXLNet may not be intended to be used without being wrapped with AttentionWrapper.
`AttentionWrapper.get_initial_state()` returns a non-empty list:

    def get_initial_state(self) -> Union[List[np.ndarray], List[TensorType]]:
        return [
            ...  # per-unit initial memory tensor; expression elided in this excerpt
            for i in range(self.gtrxl.num_transformer_units)
        ]

So I modified my script above to use AttentionWrapper instead of GTrXLNet (note the extra args that need to be included, values taken from default config settings):

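from ray.rllib.models.torch.attention_net import AttentionWrapper

model = AttentionWrapper(
    env2.observation_space,
    env2.action_space,
    num_outputs=2,  # assumed value; this argument was lost from the post
    model_config={
        "max_seq_len": 10,
        "attention_use_n_prev_actions": 0,
        "attention_use_n_prev_rewards": 0,
        "attention_dim": 64,
        "attention_num_transformer_units": 1,
        "attention_num_heads": 1,
        "attention_head_dim": 32,
        "attention_memory_inference": 50,
        "attention_memory_training": 50,
        "attention_position_wise_mlp_dim": 32,
        "attention_init_gru_gate_bias": 2.0,
    },
    name="attention_wrapper",  # assumed value; this argument was lost from the post
)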

But using this model results in an error:

AttributeError: 'AttentionWrapper' object has no attribute '_wrapped_forward'

Looking at AttentionWrapper.forward,

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        assert seq_lens is not None
        # Push obs through "unwrapped" net's `forward()` first.
        wrapped_out, _ = self._wrapped_forward(input_dict, [], None)

I realized that this model is also not intended to be used as a stand-alone model, but is supposed to wrap a different stand-alone model. So I’m worse off than I was before. I just want to be able to test GTrXLNet.forward() alone, without building an Algorithm and running train. The attention net example does not show how to do this, and as far as I’ve seen there are no examples that demonstrate using the forward method alone.
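For anyone else who hits the same `AttributeError`, here is a toy sketch of the general pattern I believe is in play: a wrapper class whose `forward` calls a `_wrapped_forward` that only exists once something has combined it with a base model. Every name here (`ToyWrapper`, `ToyBaseModel`, `wrap`) is hypothetical, and this is not RLlib's code; it only illustrates why instantiating the wrapper directly fails.

```python
# Toy illustration only: hypothetical classes, not RLlib's implementation.


class ToyWrapper:
    """Adds post-processing on top of some base model's forward()."""

    def forward(self, x):
        # Relies on `_wrapped_forward`, which this class never defines itself.
        base_out = self._wrapped_forward(x)
        return base_out + 1


class ToyBaseModel:
    """A stand-alone model with its own forward()."""

    def forward(self, x):
        return x * 2


def wrap(base_cls, wrapper_cls):
    """Combine both classes and store the base forward() under the expected name."""
    name = f"{base_cls.__name__}_as_{wrapper_cls.__name__}"
    combined = type(name, (wrapper_cls, base_cls), {})
    combined._wrapped_forward = base_cls.forward
    return combined


def call_forward(model, x):
    """Call forward() and report an AttributeError as a string instead of raising."""
    try:
        return model.forward(x)
    except AttributeError as err:
        return f"AttributeError: {err}"


bare_result = call_forward(ToyWrapper(), 3)  # wrapper used on its own
wrapped_result = call_forward(wrap(ToyBaseModel, ToyWrapper)(), 3)

print(bare_result)
print(wrapped_result)
```

The bare wrapper fails with the same kind of message I got, while the combined class works because the base model's `forward` has been attached as `_wrapped_forward`. If RLlib follows a similar approach, the combining step presumably happens when the framework builds the model, not when the wrapper is constructed by hand.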

If anyone can point me to an example of using just the forward method on any model, or explain why this is not an appropriate way to test a model, I would appreciate it.

After talking with a colleague, I’ve concluded that testing the `.forward` method in isolation is not good practice, and I have moved on to testing instantiated algorithms instead. Marking this solved.