Restoring a policy or a Keras model from a checkpoint

Hi!

How severely does this issue affect your experience of using Ray?

  • Medium: It causes significant difficulty in completing my task, but I can work around it.

The problem
During training I create checkpoints and export the Keras model, following the guidelines in Saving and Loading your RL Algorithms and Policies — Ray 2.3.0:

path_to_checkpoint = algo.save()
policy = algo.get_policy()
policy.export_model(path_to_checkpoint + "/keras_model")

When I try to restore only the policy, I get a RuntimeError from TensorFlow.
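The restore call, with its import (checkpoint_dir being the directory returned by algo.save(), i.e. path_to_checkpoint above):

from ray.rllib.policy.policy import Policy

rllib_policy = Policy.from_checkpoint(checkpoint_dir + "/policies/default_policy")

The resulting traceback: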

    rllib_policy = Policy.from_checkpoint(checkpoint_dir + "/policies/default_policy")
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 269, in from_checkpoint
    return Policy.from_state(state)
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/policy/policy.py", line 302, in from_state
    return TFPolicy._tf1_from_state_helper(state)
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/policy/tf_policy.py", line 506, in _tf1_from_state_helper
    new_policy = pol_spec.policy_class(
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/algorithms/ppo/ppo_tf_policy.py", line 83, in __init__
    base.__init__(
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py", line 89, in __init__
    timestep, explore = self._init_input_dict_and_dummy_batch(existing_inputs)
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py", line 488, in _init_input_dict_and_dummy_batch
    ) = self._create_input_dict_and_dummy_batch(self.view_requirements, {})
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/policy/dynamic_tf_policy_v2.py", line 547, in _create_input_dict_and_dummy_batch
    input_dict[view_col] = get_placeholder(
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/ray/rllib/utils/tf_utils.py", line 214, in get_placeholder
    return tf1.placeholder(
  File "/home/sia/miniconda3/envs/ray2.3/lib/python3.9/site-packages/tensorflow/python/ops/array_ops.py", line 3340, in placeholder
    raise RuntimeError("tf.placeholder() is not compatible with "
RuntimeError: tf.placeholder() is not compatible with eager execution.

The workaround is to restore the whole algorithm from the checkpoint and then take the policy from it, as sketched below. That works fine, but I do not need the entire algorithm, only the policy.
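A minimal sketch of this workaround, assuming the PPO setup above:

from ray.rllib.algorithms.algorithm import Algorithm

# Rebuild the full algorithm just to extract its default policy.
algo = Algorithm.from_checkpoint(path_to_checkpoint)
policy = algo.get_policy()

# The restored policy can then be used for inference, e.g.:
# action, state_out, info = policy.compute_single_action(obs)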

I thought that I could also use the exported TensorFlow model for inference. I can load it successfully with tf.saved_model.load(checkpoint_dir + "/keras_model/"), but the predictions from the policy and from the model differ. Since the policy is stochastic, I did not expect them to match exactly, but the predictions were very different. At first I thought this was due to some default wrappers, but when I compared the weights of the policy with the weights of the loaded model, they were also different. Moreover, across different checkpoints the weights are sometimes almost the same and sometimes differ a lot. A sketch of the comparison follows.
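This is roughly how I compared the weights (a sketch only; loaded.variables and the ordering of policy.get_weights() are assumptions about my particular setup and may need adjusting):

import numpy as np
import tensorflow as tf

loaded = tf.saved_model.load(checkpoint_dir + "/keras_model/")
exported = [v.numpy() for v in loaded.variables]

# `policy` restored via the Algorithm workaround above.
restored = list(policy.get_weights().values())

# Compare matching-shape tensors pairwise.
for w_exp, w_pol in zip(exported, restored):
    if w_exp.shape == w_pol.shape:
        print(w_exp.shape, np.max(np.abs(w_exp - w_pol)))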

So the question is: how do I properly restore a model or a policy from a checkpoint without initializing the whole RLlib algorithm?

Environment info / some configs
Python: 3.9.16
TensorFlow: 2.11.0
Ray: 2.3.0
The algorithm is PPO and the framework is tf, and the model is a custom subclassed Keras model. (framework=tf2 with tracing raises an error when the forward method of TFModelV2 reads is_training from the input_dict, so that is also an issue, but I do not use tf2 currently, and since PPO is on-policy I believe it is not a huge problem.)

I have prepared an example, derived from ray/custom_env.py at master · ray-project/ray · GitHub, that reproduces the errors mentioned above: rllib-env-model/rllib_env_model.ipynb at main · shmyak-ai/rllib-env-model · GitHub

The problem seems to be specific to subclassed Keras models. So far I have found two requirements for subclassing to work (a minimal sketch follows the list):

  1. The model instance must be assigned to self.base_model, otherwise it is not saved.
  2. The model must be called once during initialization, so that its variables are built.
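
Roughly what I mean, assuming a flat Box observation space; the network itself is an arbitrary illustration, not my actual model:

import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2


class SubclassedNet(tf.keras.Model):
    """Illustrative two-head network (policy logits + value)."""

    def __init__(self, num_outputs):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(64, activation="relu")
        self.logits_head = tf.keras.layers.Dense(num_outputs)
        self.value_head = tf.keras.layers.Dense(1)

    def call(self, obs):
        x = self.hidden(obs)
        return self.logits_head(x), self.value_head(x)


class MyModel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        # Requirement 1: the instance must be named self.base_model,
        # otherwise export_model() does not save it.
        self.base_model = SubclassedNet(num_outputs)
        # Requirement 2: call the model once so its variables are created.
        self.base_model(tf.zeros([1] + list(obs_space.shape)))

    def forward(self, input_dict, state, seq_lens):
        logits, self._value = self.base_model(input_dict["obs"])
        return logits, state

    def value_function(self):
        return tf.reshape(self._value, [-1])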