How to create a checkpoint with ray.tune.run and load it?

What is the best way to:

  1. create a checkpoint with ray.tune.run,
  2. load the checkpoint with PPOTrainer, and
  3. show the checkpointed model with TensorFlow's model.summary()?

I try to create a checkpoint with:

from ray import tune

tune.run(
    "PPO",
    config={
        "env": "CartPole-v0",
        "framework": "tf2",
        "evaluation_interval": 2,
        "evaluation_duration": 20,
    },
    local_dir="cartpole",   # results and checkpoints go under ./cartpole/PPO/...
    checkpoint_freq=2,      # write a checkpoint every 2 training iterations
)
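
If you capture the object returned by tune.run, you can look up the checkpoint path programmatically instead of copy-pasting the trial directory. A minimal sketch, assuming Ray 1.13's ExperimentAnalysis API (the stop condition and the metric/mode arguments are illustrative additions, not part of the original run):

from ray import tune

analysis = tune.run(
    "PPO",
    config={
        "env": "CartPole-v0",
        "framework": "tf2",
    },
    stop={"training_iteration": 4},   # illustrative stop condition
    local_dir="cartpole",
    checkpoint_freq=2,
    checkpoint_at_end=True,           # also checkpoint when the trial stops
    metric="episode_reward_mean",     # lets ExperimentAnalysis rank trials
    mode="max",
)

best_trial = analysis.get_best_trial(metric="episode_reward_mean", mode="max")
# Depending on the Ray version this returns a path string or a Checkpoint object.
print(analysis.get_best_checkpoint(best_trial, metric="episode_reward_mean", mode="max"))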

Next, I load the model with:

from ray.rllib.agents.ppo.ppo import PPOTrainer
import gym

config = {
    "env": "CartPole-v0",
    "framework": "tf2",
    "evaluation_interval": 2,
    "evaluation_duration": 20,
}

agent = PPOTrainer(config=config)
agent.restore("./cartpole/PPO/PPO_CartPole-v0_23a86_00000_0_2022-06-30_16-24-39/checkpoint_000002/checkpoint-2")
agent.load_checkpoint()

env = gym.make("CartPole-v0")
obs = env.reset()
while True:
    action = agent.compute_action(obs)
    obs, reward, done, _ = env.step(action)
    if done:
        break
env.close()

but I get this error:

 File "/home/ppirog/projects/Mastering-Reinforcement-Learning-with-Python/3D_observations/run_saved_checkpoint.py", line 13, in <module>
    agent.load_checkpoint()
TypeError: load_checkpoint() missing 1 required positional argument: 'checkpoint_path'
Exception ignored in: <function RolloutWorker.__del__ at 0x7fd2f9785e50>
Traceback (most recent call last):
  File "/home/ppirog/projects/Mastering-Reinforcement-Learning-with-Python/venv/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 461, in _resume_span
TypeError: 'NoneType' object is not callable
Exception ignored in: <function RolloutWorker.__del__ at 0x7fd2f9772a60>
Traceback (most recent call last):
  File "/home/ppirog/projects/Mastering-Reinforcement-Learning-with-Python/venv/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 461, in _resume_span
TypeError: 'NoneType' object is not callable
Process finished with exit code 1

Finally, I try to show the model:

import tensorflow as tf
model_path="./cartpole/PPO/PPO_CartPole-v0_23a86_00000_0_2022-06-30_16-24-39/checkpoint_000002/"
new_model = tf.keras.models.load_model(model_path)

new_model.summary()

But I get this error:

OSError: SavedModel file does not exist at: /home/ppirog/projects/Mastering-Reinforcement-Learning-with-Python/3D_observations/cartpole/PPO/PPO_CartPole-v0_2890c_00000_0_2022-06-30_16-03-19/checkpoint_000002/{saved_model.pbtxt|saved_model.pb}

@Peter_Pirog
For loading a checkpoint with an RLlib agent, please refer to this example: Training APIs — Ray 1.13.0
Specifically, the key point is that agent.restore(checkpoint_path) already loads the checkpoint, so the extra agent.load_checkpoint() call (which is what raises the TypeError about the missing checkpoint_path argument) is not needed.
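
Here is a minimal sketch of that pattern, assuming Ray 1.13's PPOTrainer API (compute_single_action is used in place of the now-deprecated compute_action):

from ray.rllib.agents.ppo import PPOTrainer
import gym

config = {
    "env": "CartPole-v0",
    "framework": "tf2",
}

agent = PPOTrainer(config=config)
# restore() already loads the weights from the checkpoint file;
# no separate agent.load_checkpoint() call is needed.
agent.restore("./cartpole/PPO/PPO_CartPole-v0_23a86_00000_0_2022-06-30_16-24-39/checkpoint_000002/checkpoint-2")

env = gym.make("CartPole-v0")
obs = env.reset()
done = False
episode_reward = 0.0
while not done:
    action = agent.compute_single_action(obs)
    obs, reward, done, _ = env.step(action)
    episode_reward += reward
env.close()
print("episode reward:", episode_reward)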

Let me know how it goes!


@xwjiang2010 Thank you, I will try it 🙂

@xwjiang2010, I tested the code and it works fine 🙂

Here is how to visualize the trained model:

import ray.rllib.agents.ppo as ppo
from tensorflow.keras.utils import plot_model

checkpoint_path = ""  # <- put the path to the saved checkpoint here
config = {}           # <- put the same configuration that was used during training here

trainer = ppo.PPOTrainer(config=config)
trainer.restore(checkpoint_path)

# The policy's base model is an ordinary tf.keras model.
model = trainer.get_policy().model.base_model

# Typical TensorFlow visualization.
model.summary()
plot_model(model, to_file="model_plot.png", show_shapes=True, show_layer_names=True,
           rankdir="TB", expand_nested=False, show_layer_activations=True)

The result is:

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 observations (InputLayer)      [(None, 4)]          0           []                               
                                                                                                  
 fc_1 (Dense)                   (None, 256)          1280        ['observations[0][0]']           
                                                                                                  
 fc_value_1 (Dense)             (None, 256)          1280        ['observations[0][0]']           
                                                                                                  
 fc_2 (Dense)                   (None, 256)          65792       ['fc_1[0][0]']                   
                                                                                                  
 fc_value_2 (Dense)             (None, 256)          65792       ['fc_value_1[0][0]']             
                                                                                                  
 fc_out (Dense)                 (None, 2)            514         ['fc_2[0][0]']                   
                                                                                                  
 value_out (Dense)              (None, 1)            257         ['fc_value_2[0][0]']             
                                                                                                  
==================================================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0
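
As a side note on point 3 of the original question: the RLlib checkpoint directory is not a TensorFlow SavedModel, which is why tf.keras.models.load_model() failed on it. If you want something loadable with load_model(), one option is to export the restored base_model yourself. A minimal sketch, assuming the trainer has been restored as above and that cartpole_saved_model is just an illustrative output directory:

import tensorflow as tf

export_dir = "cartpole_saved_model"  # illustrative output directory

# base_model is a regular Keras model, so Keras' own save/load works on it.
trainer.get_policy().model.base_model.save(export_dir)

reloaded = tf.keras.models.load_model(export_dir)
reloaded.summary()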
