[RLlib] Restoring a GTrXLNet or use_attention=True model fails

When I try to restore a model with use_attention=True, I get:

ValueError: Must pass in RNN state batches for placeholders [<tf.Tensor 'default_policy/Placeholder:0' shape=(?, ?, 64) dtype=float32>], got []

When I try passing a state explicitly:

state = [np.zeros([64], dtype=np.float32)]

while not done:
    action = agent.compute_action(obs, state)
    obs, reward, done, info = env.step(action)

I get:

Cannot feed value of shape (1, 64) for Tensor 'default_policy/Placeholder:0', which has shape '(?, ?, 64)'

How can I create an empty state with shape (?, ?, 64)?

np.zeros((0,0,64)) does not work.

Hey @digi604, great question, and sorry for the delay, which was caused by the question being “uncategorized”. It helps if you set a category (e.g. “RLlib”) when you post a new question. That way, we’ll find it more easily and can assign the right person to answer it.

Did you take a look at our attention net example script here, where we do exactly that (pass in states “manually” in an env loop)?

ray.rllib.examples.attention_net.py

Look at the commented out section:

    # To run the Trainer without tune.run, using the attention net and
    # manual state-in handling, do the following:

    # Example (use `config` from the above code):
    # >> import numpy as np
    # >> from ray.rllib.agents.ppo import PPOTrainer
    # >>
    # >> trainer = PPOTrainer(config)
    # >> num_transformers = config["model"]["attention_num_transformer_units"]
    # >> env = RepeatAfterMeEnv({})
    # >> obs = env.reset()
    # >> init_state = state = [
    # ..     np.zeros([100, 32], np.float32) for _ in range(num_transformers)
    # .. ]
    # >> while True:
    # >>     a, state_out, _ = trainer.compute_action(obs, state)
    # >>     obs, reward, done, _ = env.step(a)
    # >>     if done:
    # >>         obs = env.reset()
    # >>         state = init_state
    # >>     else:
    # >>         state = [
    # ..             np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
    # ..             for i in range(num_transformers)
    # ..         ]
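
Applied to the 64-dimensional placeholder from the question, the state handling above can be sketched in plain NumPy as follows. This is only a minimal illustration of the shapes involved, not RLlib code: the helper names `init_attention_state` and `roll_attention_state` are hypothetical, and the sizes (1 transformer unit, memory length 50, attention dim 64) are assumptions that should be checked against your own model config (I believe the relevant keys are `attention_num_transformer_units`, `attention_memory_inference`, and `attention_dim`). The point is that each state entry is a 2-D array of shape (memory length, attention dim), not a 1-D vector of shape (64,), and the batch dimension is added when the state is fed.

```python
import numpy as np


def init_attention_state(num_units, memory_len, attention_dim):
    """Return one zero-filled (memory_len, attention_dim) array per transformer unit."""
    return [
        np.zeros((memory_len, attention_dim), dtype=np.float32)
        for _ in range(num_units)
    ]


def roll_attention_state(state, state_out):
    """Append the newest output row per unit and drop the oldest row.

    The window length stays fixed, mirroring the np.concatenate(...)[1:]
    update in the example script.
    """
    return [
        np.concatenate([s, [out]], axis=0)[1:]
        for s, out in zip(state, state_out)
    ]


# Hypothetical sizes; take the real ones from your model config.
state = init_attention_state(num_units=1, memory_len=50, attention_dim=64)

# Stand-in for the state_out returned by one compute_action call.
fake_state_out = [np.ones(64, dtype=np.float32)]
state = roll_attention_state(state, fake_state_out)

print(state[0].shape)
```

In the real loop, `fake_state_out` would be replaced by the second return value of `trainer.compute_action(obs, state)`, and the state would be reset to the zero-filled initial state whenever the episode ends.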