- High: It blocks me from completing my task.
I’m trying to modify the example `GTrXLNet` to use an action mask, but my first step is just to get the example model working. The example shows how to use the model with `Tuner.fit()`, but not how to use it on its own.
I’m using the code below and getting the error message below it. As best I can tell, the issue is in my values for `state` and/or `seq_lens`; my guesses about the expected shapes are further down, after the traceback. `model.get_initial_state()` returns an empty list. That seems like it should be an error, but the Ray example uses this class just fine, so I’m not sure why that method returning an empty list works there.
What am I misunderstanding about using this class?
```python
import gymnasium as gym
from numpy import array
from ray.rllib.examples.env.random_env import RandomEnv
from ray.rllib.models.torch.attention_net import GTrXLNet
from torch import tensor

# Random-sampling stand-in env with a simple Box observation space.
env2 = RandomEnv(
    {
        "observation_space": gym.spaces.Box(0, 1, shape=[4]),
        "action_space": gym.spaces.MultiDiscrete([2]),
    }
)
env2.reset()

# %% Ray example model
model = GTrXLNet(
    observation_space=env2.observation_space,
    action_space=env2.action_space,
    num_outputs=2,
    model_config={"max_seq_len": 10},
    name="foo",
)

# Single forward pass; state comes from get_initial_state(), which returns [].
[logits, state] = model.forward(
    input_dict={"obs": tensor(env2.observation_space.sample())},
    state=model.get_initial_state(),
    seq_lens=tensor(array([1])),
)
```
Error:

```
Traceback (most recent call last):
  File "/home/user/clockpunch/tests/nets/test_gtrxl.py", line 59, in <module>
    [logits, state] = model.forward(
  File "/home/user/anaconda3/envs/punch/lib/python3.10/site-packages/ray/rllib/models/torch/attention_net.py", line 223, in forward
    all_out = self.layers[i](all_out, memory=state[i // 2])
IndexError: list index out of range
```
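One thing I noticed while digging: as far as I can tell from the source, `forward()` infers the batch and time dimensions from its inputs rather than taking them explicitly, so my unbatched observation may also be part of the problem. Here is my reading (continuing from the repro code above; the `B`/`T` logic is my inference from the source, not anything documented):

```python
# My reading of GTrXLNet.forward() internals (an assumption, not documented):
#   B = len(seq_lens)        # batch size
#   T = obs.shape[0] // B    # time steps folded into the leading obs axis
# With seq_lens = [1] and an unbatched obs of shape [4], that gives T = 4,
# which is probably not what I intend. A single batched timestep would be:
obs_batch = tensor(env2.observation_space.sample()).unsqueeze(0)  # shape [1, 4]
seq_lens = tensor(array([1]))  # B = 1, so T = 1
```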
Here is an excerpt from `GTrXLNet.forward()` near the error, with the values of the relevant variables at the moment of the failure:
```python
all_out = observations
memory_outs = []
for i in range(len(self.layers)):
    # MHA layers which need memory passed in.
    if i % 2 == 1:
        all_out = self.layers[i](all_out, memory=state[i // 2])  # error here
    else:
        all_out = self.layers[i](all_out)
        memory_outs.append(all_out)

# Values at the time of the error:
# i = 1
# all_out.shape = torch.Size([1, 64])
# state = []
# len(memory_outs) = 1
# memory_outs[0].shape = torch.Size([1, 64])
```
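For completeness, here is my best guess at what `state` would need to look like, based on my reading of the `GTrXLNet` constructor defaults (`num_transformer_units=1`, `attention_dim=64`, `memory_inference=50`; the shapes are my assumption, not from the docs): one memory tensor per transformer unit, shaped `[batch, memory, attention_dim]`. Is something like this what the trajectory view API builds internally when `get_initial_state()` returns an empty list?

```python
import torch

# Assumed from the constructor defaults, not from any documentation:
num_transformer_units = 1  # one GTrXL block -> one memory entry in state
memory_inference = 50      # length of the attention memory at inference
attention_dim = 64         # matches all_out.shape = [1, 64] above

# One zero-initialized memory tensor per transformer unit.
manual_state = [
    torch.zeros(1, memory_inference, attention_dim)
    for _ in range(num_transformer_units)
]

logits, state_out = model.forward(
    input_dict={"obs": tensor(env2.observation_space.sample()).unsqueeze(0)},
    state=manual_state,
    seq_lens=tensor(array([1])),
)
```

I haven’t confirmed this is the intended usage, so corrections welcome.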