[rllib] Problem running compute_single_action from PPO restored checkpoint

I am restoring a checkpoint from a Tune experiment trial and am attempting to manually compute a single action from a policy given an observation.


import numpy as np

obs = algo.get_policy("learned").observation_space.sample()  # guaranteed to be the right dimensions
lstm_cell_size = algo.config.model["lstm_cell_size"]
state = [np.zeros([lstm_cell_size], np.float32) for _ in range(2)]  # create an empty (zeroed) state

algo.compute_single_action(obs, state, policy_id="learned", explore=False)

algo is just a restored PPO instance. Since I have use_lstm set to True, it appears that I must pass a state into compute_single_action, but I get the following error (truncated):

File ~/miniconda3/envs/test/lib/python3.10/site-packages/ray/rllib/core/models/base.py:417, in StatefulActorCriticEncoder._forward(self, inputs, **kwargs)
    415 actor_inputs = inputs.copy()
    416 critic_inputs = inputs.copy()
--> 417 actor_inputs[STATE_IN] = inputs[STATE_IN][ACTOR]
    418 critic_inputs[STATE_IN] = inputs[STATE_IN][CRITIC]
    420 actor_out = self.actor_encoder(actor_inputs, **kwargs)

IndexError: too many indices for tensor of dimension 3

What gives? I'm not sure how to properly pass in the state (or even what its dimensions should be, apparently).
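For context on where the IndexError comes from: the traceback is inside RLlib's newer RLModule stack, where StatefulActorCriticEncoder._forward indexes inputs[STATE_IN] with per-sub-encoder keys (ACTOR / CRITIC), so it expects a nested dict of states, while a hand-built flat list of zero tensors matches the older ModelV2-style layout. A rough numpy sketch of the two layouts (the exact key names and batch dimensions below are assumptions for illustration, not copied from the RLlib source):

```python
import numpy as np

lstm_cell_size = 256

# Older ModelV2-style state: a flat list [h, c] of zero vectors,
# which is what the question above builds by hand.
state_flat = [np.zeros([lstm_cell_size], np.float32) for _ in range(2)]

# Newer RLModule-style state (keys/shapes assumed for illustration):
# the stateful actor-critic encoder looks up one entry per sub-encoder,
# so indexing a flat list the same way raises an error instead.
state_nested = {
    "actor": {"h": np.zeros([1, lstm_cell_size], np.float32),
              "c": np.zeros([1, lstm_cell_size], np.float32)},
    "critic": {"h": np.zeros([1, lstm_cell_size], np.float32),
               "c": np.zeros([1, lstm_cell_size], np.float32)},
}
```

The safe way to get the right structure for whichever stack is in use is to ask the policy itself rather than guessing the layout.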

My config for evaluation is identical to the config used for training, with the exception of setting explore to False in the evaluation config:

config = (  # 1. Configure the algorithm
    PPOConfig()
    .training(model={"fcnet_hiddens": [64, 64]})
    .evaluation(evaluation_num_workers=2, evaluation_config={"explore": False})
    .multi_agent(
        policies={
            "random_action": PolicySpec(
                observation_space=gym.spaces.Box(-1e18, 1e18, (4,)),
            ),
            "learned": PolicySpec(
                config={"model": {"use_lstm": True}},
                observation_space=gym.spaces.Box(-1e18, 1e18, (4,)),
            ),
        },
    )
)
config.sgd_minibatch_size = 128
config.train_batch_size = int(256 * 4)
config.env_config = {"num_agents": 50, "episode_length": 1000}

To be sure, these checkpoints were created during trials run by a call to tune.Tuner.fit(). I read other material describing similar problems and tried to follow it, to no avail.

Solved: I was passing an incorrectly structured state (on this stack, PPO expects a dict rather than a flat list). Instead of building the state by hand, get it from the policy:

state = algo.get_policy("learned").get_initial_state()
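One more gotcha when rolling out an LSTM policy: when a state is passed in, compute_single_action returns a tuple (action, state_out, extra_info), and state_out must be fed into the next call. The snippet below is a toy stand-in (the compute_single_action defined here is a fake, not RLlib's) purely to show the state-threading pattern:

```python
import numpy as np

# Fake stand-in for algo.compute_single_action, mimicking its
# (action, state_out, extra_info) return shape when a state is passed.
def compute_single_action(obs, state, policy_id="learned", explore=False):
    # Pretend the recurrent cell updated its state each step.
    new_state = [s + 1.0 for s in state]
    return 0, new_state, {}

# In real code this would come from get_initial_state().
state = [np.zeros([256], np.float32) for _ in range(2)]

for t in range(5):
    obs = np.zeros(4, np.float32)  # placeholder observation
    # Thread the returned state back into the next call.
    action, state, _ = compute_single_action(obs, state)
```

Forgetting to carry state_out forward (i.e., always passing the initial state) silently degrades an LSTM policy, since the recurrent memory never accumulates.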