[Tune][RLlib] How to use a Tune-trained (RNN) model for inference?

I have a fairly basic question: how do I load a custom RNN model that I trained with Tune on a custom environment, so that I can use it to make predictions? I have configured Tune to write checkpoints during training.
Suppose I have something like:


The model and environment are configured in config, following the tutorials and examples on GitHub.
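A training run of this kind, with Tune writing checkpoints, could look roughly like the sketch below. It assumes the Ray 1.x API (tune.run with checkpoint_freq and checkpoint_at_end, plus ExperimentAnalysis.get_best_trial and get_best_checkpoint), that config already refers to the registered custom model and environment, and that "episode_reward_mean" is the metric used to pick a checkpoint. train_with_checkpoints is a hypothetical helper name, and depending on the Ray version get_best_checkpoint may return a path string or a Checkpoint object.

```python
def train_with_checkpoints(config, stop_iters=100):
    """Hypothetical helper: train PPO with Tune and return the best checkpoint.

    Assumes the Ray 1.x tune.run / ExperimentAnalysis API.
    """
    from ray import tune  # imported lazily so this module loads without Ray

    analysis = tune.run(
        "PPO",
        config=config,
        stop={"training_iteration": stop_iters},
        checkpoint_freq=10,      # write a checkpoint every 10 iterations
        checkpoint_at_end=True,  # and one more after the last iteration
    )
    # Select the trial, then that trial's checkpoint, with the highest mean reward.
    trial = analysis.get_best_trial(metric="episode_reward_mean", mode="max")
    return analysis.get_best_checkpoint(
        trial, metric="episode_reward_mean", mode="max"
    )
```

The returned value is what the later restore step needs as its checkpoint path.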

From custom_rnn_model.py I got an idea of how to do manual inference with the compute_action function, but it is not clear to me how to load a model from a checkpoint and then run manual inference with it. Are there any examples of this, or can someone explain how to do it?
Extract from custom_rnn_model.py:

    # To run the Trainer without tune.run, using our RNN model and
    # manual state-in handling, do the following:

    # Example (use `config` from the above code):
    # >> import numpy as np
    # >> from ray.rllib.agents.ppo import PPOTrainer
    # >>
    # >> trainer = PPOTrainer(config)
    # >> lstm_cell_size = config["model"]["custom_model_config"]["cell_size"]
    # >> env = RepeatAfterMeEnv({})
    # >> obs = env.reset()
    # >>
    # >> # range(2) b/c h- and c-states of the LSTM.
    # >> init_state = state = [
    # ..     np.zeros([lstm_cell_size], np.float32) for _ in range(2)
    # .. ]
    # >>
    # >> while True:
    # >>     a, state_out, _ = trainer.compute_action(obs, state)
    # >>     obs, reward, done, _ = env.step(a)
    # >>     if done:
    # >>         obs = env.reset()
    # >>         state = init_state
    # >>     else:
    # >>         state = state_out

The file shows that I need to create an initial state for the RNN and then use the usual loop to step through the environment until it returns done. My question is: how can I initialize the model (or the PPOTrainer) from one of my checkpoints?
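The loop from the extract can be factored into a small, framework-agnostic function, which makes the state handling explicit: the state is reset at the start of each episode and the returned state_out is fed back in on every step. This is a sketch: run_episode is a hypothetical helper, compute_action is assumed to be any callable with the same (obs, state) -> (action, state_out, info) shape that trainer.compute_action(obs, state) has in the extract, and env is assumed to use the old Gym API where step() returns four values.

```python
from typing import Any, Callable, List, Sequence, Tuple

# Hypothetical type for one recurrent policy step:
# (obs, state) -> (action, state_out, info), as in the extract's compute_action call.
StepFn = Callable[[Any, List[Any]], Tuple[Any, List[Any], dict]]


def run_episode(compute_action: StepFn, env, init_state: Sequence[Any]) -> float:
    """Roll out one episode, threading the RNN state through every step."""
    obs = env.reset()
    state = list(init_state)  # start every episode from the initial state
    total_reward = 0.0
    done = False
    while not done:
        action, state_out, _ = compute_action(obs, state)
        obs, reward, done, _ = env.step(action)
        total_reward += reward
        state = state_out  # the new RNN state becomes the next step's input
    return total_reward
```

With a trainer, one could pass lambda obs, s: trainer.compute_action(obs, s) together with the zero-initialized h- and c-states from the extract.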

Hi @LukasNothhelfer,

trainer.get_initial_state() should be able to provide the initial state for you.

Take a look at this code snippet from rollout.py that shows one way to restore weights and state from a checkpoint. In this example agent == trainer.
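A minimal sketch of that restore pattern follows. It assumes the Ray 1.x agents API used in the extract (PPOTrainer and Trainer.restore), Tune's default directory layout, where params.pkl sits in the trial directory one level above each checkpoint folder, and a saved config that contains the "env" key. find_params_pkl and restore_ppo_trainer are hypothetical helper names. The custom model and environment must be registered again in the inference process under the same names used in training, because the pickled config only refers to them by name.

```python
import os
import pickle


def find_params_pkl(checkpoint_path: str) -> str:
    """Locate the params.pkl Tune saved for the trial owning this checkpoint.

    Checks the checkpoint's own directory first, then its parent
    (the trial directory), similar to the lookup in rllib's rollout.py.
    """
    config_dir = os.path.dirname(checkpoint_path)
    candidate = os.path.join(config_dir, "params.pkl")
    if not os.path.exists(candidate):
        candidate = os.path.join(config_dir, "..", "params.pkl")
    if not os.path.exists(candidate):
        raise FileNotFoundError(f"no params.pkl found for {checkpoint_path}")
    return candidate


def restore_ppo_trainer(checkpoint_path: str):
    """Rebuild a PPOTrainer from a Tune checkpoint (Ray 1.x API assumed).

    Register the custom model and env before calling this, e.g. with
    ModelCatalog.register_custom_model and tune.register_env.
    """
    import ray.cloudpickle as cloudpickle  # Tune pickles params with cloudpickle
    from ray.rllib.agents.ppo import PPOTrainer

    with open(find_params_pkl(checkpoint_path), "rb") as f:
        config = cloudpickle.load(f)
    config["num_workers"] = 0  # inference only: no remote rollout workers
    trainer = PPOTrainer(config=config)
    trainer.restore(checkpoint_path)  # load the saved trainer/policy state
    return trainer
```

A checkpoint path typically points at the file inside a checkpoint folder, so the trial directory is found one level up.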


@mannyv Thanks for sharing the snippet. I'll try it out and let you know.

A first comment: the Trainer class has no get_initial_state() method (Trainer class: ray/trainer.py at master · ray-project/ray · GitHub), but you can get the initial state via trainer.get_policy().get_initial_state().
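Given a restored trainer, the initial state from the policy and the stepping loop from the extract combine as in this sketch. It assumes the Ray 1.x Trainer.compute_action signature, which returns (action, state_out, info) when a state is passed (later releases rename it compute_single_action), the default policy, and an env with the four-value step() API. rollout is a hypothetical helper name.

```python
def rollout(trainer, env, num_episodes=1):
    """Run a restored trainer with a recurrent policy; return per-episode rewards.

    Assumes the Ray 1.x Trainer API (get_policy, compute_action).
    """
    policy = trainer.get_policy()  # the default policy
    rewards = []
    for _ in range(num_episodes):
        state = policy.get_initial_state()  # e.g. [h, c] for an LSTM
        obs = env.reset()
        done = False
        episode_reward = 0.0
        while not done:
            # With a state argument, compute_action returns (action, state_out, info).
            action, state, _ = trainer.compute_action(obs, state=state)
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
        rewards.append(episode_reward)
    return rewards
```

Fetching a fresh initial state per episode avoids carrying RNN memory across episode boundaries.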

The official documentation explains how to do it; I had overlooked that. Thank you @manny for your help.
Official documentation