Using trained policy with attention net reports assert seq_lens is not None error

Hello

After successfully training a policy with Tune and RLlib using use_attention, I try to restore the policy so I can actually use it (inference? sorry, I'm still grasping the concepts here :slight_smile: )

But I get an error from the attention net because seq_lens is None.

I believe this is the relevant part of the code:

ppo_config = (
    PPOConfig()
    .resources(num_cpus_per_worker=1)
    .environment(env="ACustomGym", disable_env_checking=True)
    .rollouts(
        preprocessor_pref=None,
        observation_filter="NoFilter",
        compress_observations=False,
        num_rollout_workers=1,
    )
    .framework(framework="tf2", eager_tracing=True)
    .training(model={"use_attention": True})
    .experimental(_disable_preprocessor_api=True)
)

analysis = tune.run(
    "PPO",
    metric="episode_reward_mean",
    mode="max",
    scheduler=pbt,
    num_samples=1,
    config=ppo_config.to_dict(),
    stop=stopping_criteria,
    local_dir="D:/ray_results",
    checkpoint_freq=1,
    raise_on_failed_trial=False,
)

best_result = analysis.get_best_trial()
loaded_ppo = Algorithm.from_checkpoint(best_result.checkpoint.to_air_checkpoint())
loaded_policy = loaded_ppo.get_policy()

my_gym = ACustomGym()
obs,info = my_gym.reset()
done = False
total_r = 0
while not done:
    action,_,_ = loaded_policy.compute_single_action(obs)
    obs, reward, terminated, truncated, info = my_gym.step(action)
    done = terminated or truncated
    total_r = total_r + reward
print(total_r)

But then the error I get on line

action,_,_ = loaded_policy.compute_single_action(obs)

is this one:

Traceback (most recent call last):
  File "d:\generic\main_PB2.py", line 288, in <module>
    action,_,_ = loaded_policy.compute_single_action(obs)
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\policy\policy.py", line 489, in compute_single_action
    out = self.compute_actions_from_input_dict(
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\policy\eager_tf_policy.py", line 138, in _func
    return obj(self_, *args, **kwargs)
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\policy\eager_tf_policy.py", line 190, in compute_actions_from_input_dict
    return super(TracedEagerPolicy, self).compute_actions_from_input_dict(
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 461, in compute_actions_from_input_dict
    ret = self._compute_actions_helper(
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\policy\eager_tf_policy.py", line 96, in _func
    return func(*eager_args, **eager_kwargs)
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\tensorflow\python\util\traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\utils\threading.py", line 24, in wrapper
    return func(self, *a, **k)
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\policy\eager_tf_policy_v2.py", line 821, in _compute_actions_helper
    dist_inputs, state_out = self.model(
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\models\modelv2.py", line 259, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "C:\ProgramData\Anaconda3\envs\py3108\lib\site-packages\ray\rllib\models\tf\attention_net.py", line 478, in forward
    assert seq_lens is not None
AssertionError

I’ve found this, which seems totally related, but I can’t figure out how to apply the solution proposed there (if that is indeed the solution).

Please, any help will be very much appreciated. The attention net improves the results of my training, and I’m stuck not being able to use the trained policy because of this. I’m sure I’m missing something obvious, as it should be straightforward to use a policy after tuning and training it, but I can’t figure it out.

Thank you

Hi @PREJAN,
Policy.compute_single_action can be messy when it comes to the assumptions RLlib makes about which tensors (and on which device) should be passed to it.

Try using Algorithm.compute_single_action instead!
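Here is a minimal sketch of that route, mirroring the pattern in RLlib's attention-net example. It assumes the default attention model keys (attention_num_transformer_units, attention_dim, attention_memory_inference) and your loaded_ppo / my_gym objects from above; the key point is that an attention policy needs its memory state passed in explicitly and shifted forward after every step:

import numpy as np

model_cfg = loaded_ppo.get_policy().config["model"]
memory = model_cfg["attention_memory_inference"]           # default 50
num_units = model_cfg["attention_num_transformer_units"]   # default 1
attn_dim = model_cfg["attention_dim"]                      # default 64

# One zero-filled memory tensor per transformer unit as the initial state.
state = [np.zeros((memory, attn_dim), dtype=np.float32) for _ in range(num_units)]

obs, info = my_gym.reset()
done = False
total_r = 0
while not done:
    # With a state passed in, Algorithm.compute_single_action returns
    # (action, state_out, extra_fetches).
    action, state_out, _ = loaded_ppo.compute_single_action(obs, state=state, explore=False)
    obs, reward, terminated, truncated, info = my_gym.step(action)
    # Shift the memory window: append the new entry, drop the oldest row.
    state = [
        np.concatenate([state[i], [state_out[i]]], axis=0)[1:]
        for i in range(num_units)
    ]
    done = terminated or truncated
    total_r += reward
print(total_r)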
Apart from passing seq_lens in the call, you can also try the new RL Modules API.
RL Modules have simple forward_inference and forward_exploration methods that do not need seq_lens in most cases.
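A very rough sketch of that route (the config flags for enabling RL Modules and the way you obtain the module differ between Ray releases, so treat these names as assumptions to verify against your version): after training with the RL Module API enabled, inference is a single forward_inference call on a batched observation dict.

import numpy as np

# Assumption: with the RL Module API enabled at training time, the policy's
# `model` attribute is the RLModule instance (this depends on your Ray version).
rl_module = loaded_ppo.get_policy().model

obs, info = my_gym.reset()
batch = {"obs": np.expand_dims(obs, 0)}   # RL Modules expect a batched input dict
out = rl_module.forward_inference(batch)
# `out` holds the action-distribution inputs; build the action distribution
# from them (or take its mode) to get a deterministic action.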