[RLlib] Problem with TFModelV2 loading after having saved one with `TFPolicy.export_model()`

Hello everybody! I am new to topic creation in general so I would request a bit of patience if I do not do everything correctly from the start :smiley:

Background

I use trainer.get_policy().export_model() to export a TF model. The model is exported successfully.

What’s the problem?

I am unable to load the exported model with tf.saved_model.load() or to run it with the SavedModel CLI.

Script

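import ray
import tensorflow as tf
from ray.rllib.agents.dqn import DQNTrainer
from ray.tune.logger import pretty_print

ray.init()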
trainer = DQNTrainer(env="CartPole-v0")

for i in range(2):
   result = trainer.train()
   print(pretty_print(result))

   if i % 1 == 0:
       checkpoint = trainer.save()
       print("checkpoint saved at", checkpoint)


trainer.get_policy().export_model('test_model')
predict_fn = tf.saved_model.load('test_model')

Error message

This function will only be available through the v1 compatibility library as tf.compat.v1.saved_model.utils.build_tensor_info or tf.compat.v1.saved_model.build_tensor_info.
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/lr:0' shape=() dtype=float32_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/timestep_1:0' shape=() dtype=int64_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/global_step:0' shape=() dtype=int64_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/lr:0' shape=() dtype=float32_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/timestep_1:0' shape=() dtype=int64_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/global_step:0' shape=() dtype=int64_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Some variables could not be lifted out of a loaded function. Run the tf.initializers.tables_initializer() operation to restore these variables.
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/lr:0' shape=() dtype=float32_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/timestep_1:0' shape=() dtype=int64_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
WARNING:tensorflow:Unable to create a python object for variable <tf.Variable 'default_policy/global_step:0' shape=() dtype=int64_ref> because it is a reference variable. It may not be visible to training APIs. If this is a problem, consider rebuilding the SavedModel after running tf.compat.v1.enable_resource_variables().
Traceback (most recent call last):
  File "Python36\python.exe\lib\code.py", line 91, in runcode
    exec(code, self.locals)
  File "<input>", line 1, in <module>
  File "\plugins\python\helpers\pydev\_pydev_bundle\pydev_umd.py", line 197, in runfile
    pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
  File "\plugins\python\helpers\pydev\_pydev_imps\_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "test.py", line 64, in <module>
    predict_fn = tf.saved_model.load('test_model')
  File "\venv\lib\site-packages\tensorflow\python\saved_model\load.py", line 859, in load
    return load_internal(export_dir, tags, options)["root"]
  File "\venv\lib\site-packages\tensorflow\python\saved_model\load.py", line 909, in load_internal
    root = load_v1_in_v2.load(export_dir, tags)
  File "\venv\lib\site-packages\tensorflow\python\saved_model\load_v1_in_v2.py", line 279, in load
    return loader.load(tags=tags)
  File "\venv\lib\site-packages\tensorflow\python\saved_model\load_v1_in_v2.py", line 262, in load
    signature_functions = self._extract_signatures(wrapped, meta_graph_def)
  File "\venv\lib\site-packages\tensorflow\python\saved_model\load_v1_in_v2.py", line 169, in _extract_signatures
    signature_fn = wrapped.prune(feeds=feeds, fetches=fetches)
  File "\venv\lib\site-packages\tensorflow\python\eager\wrap_function.py", line 338, in prune
    base_graph=self._func_graph)
  File "\venv\lib\site-packages\tensorflow\python\eager\lift_to_graph.py", line 260, in lift_to_graph
    add_sources=add_sources))
  File "\venv\lib\site-packages\tensorflow\python\ops\op_selector.py", line 413, in map_subgraph
    % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
tensorflow.python.ops.op_selector.UnliftableError: A SavedModel signature needs an input for each placeholder the signature's outputs use. An output for signature 'serving_default' depends on a placeholder which is not an input (i.e. the placeholder is not fed a value).
Unable to lift tensor <tf.Tensor 'default_policy/zeros_like_1:0' shape=(?,) dtype=float32> because it depends transitively on placeholder <tf.Operation 'default_policy/timestep' type=Placeholder> via at least one path, e.g.: default_policy/zeros_like_1 (Fill) <- default_policy/zeros_like_1/Const (Const) <- default_policy/Assign (Assign) <- default_policy/timestep (Placeholder)

Additional info for SavedModel CLI

I also tried using this example and again have a problem loading the model created.

Finally I tried using the model with SavedModel CLI of tensorflow following this. The input I used was:

saved_model_cli run --dir test_model --tag_set serve --signature_def serving_default --inputs observations=test.npy

The test.npy has a simple numpy array created by np.array([[0.1, 0.2, 0.3, 0.4]])

The resulting error was essentially the following (I removed most of the error message; I can add it later if wanted):

tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'default_policy/timestep' with dtype int64
         [[node default_policy/timestep (defined at C:\Users\lgravias\PycharmProjects\RL_lib_tester\venv\Scripts\saved_model_cli.exe\__main__.py:7) ]]

To my understanding, this means that the timestep needs an input, which however is not included in the inputs of the model.
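
To make the rule in the UnliftableError message concrete, here is a toy sketch in plain Python (no TensorFlow involved). The op names are taken from the error message where possible; the `actions_0` output depending on `default_policy/zeros_like_1`, the placeholder name `default_policy/observation`, and the helper `unfed_placeholders` are assumptions made up for illustration. It walks the dependencies of an output and reports every placeholder that is reached but not listed as a signature input.

```python
from typing import Dict, List, Set

# Toy graph: each op maps to the ops it reads from. The chain from
# zeros_like_1 down to timestep mirrors the error message; everything
# else is a hand-written assumption.
GRAPH: Dict[str, List[str]] = {
    "default_policy/zeros_like_1": ["default_policy/zeros_like_1/Const"],
    "default_policy/zeros_like_1/Const": ["default_policy/Assign"],
    "default_policy/Assign": ["default_policy/timestep"],
    "default_policy/timestep": [],      # placeholder
    "default_policy/observation": [],   # placeholder (hypothetical name)
    "actions_0": ["default_policy/observation", "default_policy/zeros_like_1"],
}
PLACEHOLDERS: Set[str] = {"default_policy/timestep", "default_policy/observation"}


def unfed_placeholders(output: str, fed: Set[str]) -> Set[str]:
    """Return the placeholders that `output` depends on but that are not fed."""
    seen: Set[str] = set()
    stack = [output]
    while stack:
        op = stack.pop()
        if op in seen:
            continue
        seen.add(op)
        stack.extend(GRAPH[op])
    return (seen & PLACEHOLDERS) - fed


# Signature that only lists the observation input.
print(unfed_placeholders("actions_0", {"default_policy/observation"}))
# Signature that also lists the timestep placeholder.
print(unfed_placeholders(
    "actions_0", {"default_policy/observation", "default_policy/timestep"}))
```

With only the observation fed, the timestep placeholder is reported as missing; once it is listed as an input, nothing is missing.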

Am I understanding something incorrectly? Am I using the framework incorrectly? Or is this a bug?

Thanks to anyone who may take the time to help!

After some debugging and searching I have found a workaround for the problem. Since I saw a post from @kepricon detailing the same issue, I figured I should show my solution here.

The issue stems from the _build_signature_def() function of the TFPolicy class in tf_policy.py: the timestep placeholder is not added to input_signature, although the exported model needs it. The new function is the following:

def _build_signature_def(self):
    """Build signature def map for tensorflow SavedModelBuilder.
    """
    # build input signatures
    input_signature = self._extra_input_signature_def()
    input_signature["observations"] = \
        tf1.saved_model.utils.build_tensor_info(self._obs_input)

    if self._seq_lens is not None:
        input_signature["seq_lens"] = \
            tf1.saved_model.utils.build_tensor_info(self._seq_lens)
    
    ### THIS IS WHAT I ADDED ###
    if self._timestep is not None:
        input_signature["timestep"] = \
            tf1.saved_model.utils.build_tensor_info(self._timestep)
    ### END OF ADDITION ###
    if self._prev_action_input is not None:
        input_signature["prev_action"] = \
            tf1.saved_model.utils.build_tensor_info(
                self._prev_action_input)
    if self._prev_reward_input is not None:
        input_signature["prev_reward"] = \
            tf1.saved_model.utils.build_tensor_info(
                self._prev_reward_input)
    input_signature["is_training"] = \
        tf1.saved_model.utils.build_tensor_info(self._is_training)

    for state_input in self._state_inputs:
        input_signature[state_input.name] = \
            tf1.saved_model.utils.build_tensor_info(state_input)

    # build output signatures
    output_signature = self._extra_output_signature_def()
    for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
        output_signature["actions_{}".format(i)] = \
            tf1.saved_model.utils.build_tensor_info(a)

    for state_output in self._state_outputs:
        output_signature[state_output.name] = \
            tf1.saved_model.utils.build_tensor_info(state_output)
    signature_def = (
        tf1.saved_model.signature_def_utils.build_signature_def(
            input_signature, output_signature,
            tf1.saved_model.signature_constants.PREDICT_METHOD_NAME))
    signature_def_key = (tf1.saved_model.signature_constants.
                         DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    signature_def_map = {signature_def_key: signature_def}
    return signature_def_map

For this example to work properly, I changed the train_and_export() function to the following, which uses the new _build_signature_def():

def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    cls = get_agent_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    for _ in range(num_steps):
        alg.train()

    # Export tensorflow checkpoint for fine-tuning
    # alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)
    policy = alg.get_policy()
    with policy._sess.graph.as_default():
        tf1.global_variables_initializer()
        builder = tf1.saved_model.builder.SavedModelBuilder(model_dir)
        signature_def_map = _build_signature_def(policy)
        builder.add_meta_graph_and_variables(
            policy._sess, [tf1.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
            # saver=tf1.summary.FileWriter(model_dir).add_graph(graph=policy._sess.graph),
            strip_default_attrs=False)
        # builder.add_meta_graph([tf1.saved_model.tag_constants.SERVING], signature_def_map=signature_def_map, strip_default_attrs=True)
        builder.save()

Now the timestep is shown as an input to the exported model and a value can be passed to it. I am not sure whether the timestep is meant to be an actual input to the model or whether a default value is missing (I do not have a lot of experience with TensorFlow). In any case, I will probably open an issue on GitHub as well.

Hey @morsias, thanks for posting this problem and its solution. The timestep placeholder is not fed into the model directly, but is (sometimes) used by the Policy’s exploration components, which are part of the graph (but not of the TFModelV2).
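
As a rough illustration of why an exploration component would need the timestep, here is a sketch of epsilon-greedy action selection with a linearly annealed epsilon. This is not RLlib's implementation; the function names and the schedule values (1.0 down to 0.02 over 10000 steps) are assumptions chosen for the example.

```python
import random
from typing import Sequence


def epsilon_at(timestep: int, initial: float = 1.0, final: float = 0.02,
               decay_steps: int = 10000) -> float:
    """Linearly anneal epsilon from `initial` to `final` over `decay_steps`."""
    fraction = min(max(timestep, 0) / decay_steps, 1.0)
    return initial + fraction * (final - initial)


def select_action(q_values: Sequence[float], timestep: int,
                  rng: random.Random) -> int:
    """Epsilon-greedy choice: depends on the Q-values AND on the timestep."""
    if rng.random() < epsilon_at(timestep):
        # Explore: pick a uniformly random action.
        return rng.randrange(len(q_values))
    # Exploit: pick the action with the largest Q-value.
    return max(range(len(q_values)), key=lambda i: q_values[i])


rng = random.Random(0)
for t in (0, 5000, 20000):
    print(t, epsilon_at(t), select_action([4.2693458, 5.365109], t, rng))
```

The point of the sketch is only that the sampled action is a function of the timestep as well as of the observation, so a graph that includes such a component has a second input even though the model itself does not.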

I’ll provide a PR for this fix.

Actually, I’m not able to reproduce this problem. When I modify the rllib/tests/test_export.py test case to re-load the model via tf.saved_model.load(...) (after it has been saved via policy.export_model()), it works fine and I don’t see the error you described.

Could you provide a short, self-sufficient reproduction script that would produce this error?

@morsias ^

@sven1977 I think I have a similar issue.
I shared the reproducible script in my post.

First of all, @sven1977, thank you for the response, and sorry for taking a long time to respond. Essentially, the problem is not in the exporting/loading process but in the inference process.

I used the following script to export a model based on this:

import os
import ray

from ray.rllib.agents.registry import get_agent_class
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

ray.init(num_cpus=10)


def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    cls = get_agent_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    for _ in range(num_steps):
        alg.train()

    # Export tensorflow checkpoint for fine-tuning
    alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)
    # Export tensorflow SavedModel for online serving
    alg.export_policy_model(model_dir)


def restore_saved_model(export_dir):
    signature_key = \
        tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    g = tf1.Graph()
    with g.as_default():
        with tf1.Session(graph=g) as sess:
            meta_graph_def = \
                tf1.saved_model.load(sess,
                                     [tf1.saved_model.tag_constants.SERVING],
                                     export_dir)
            print("Model restored!")
            print("Signature Def Information:")
            print(meta_graph_def.signature_def[signature_key])
            print("You can inspect the model using TensorFlow SavedModel CLI.")
            print("https://www.tensorflow.org/guide/saved_model")


def restore_checkpoint(export_dir, prefix):
    sess = tf1.Session()
    meta_file = "%s.meta" % prefix
    saver = tf1.train.import_meta_graph(os.path.join(export_dir, meta_file))
    saver.restore(sess, os.path.join(export_dir, prefix))
    print("Checkpoint restored!")
    print("Variables Information:")
    for v in tf1.trainable_variables():
        value = sess.run(v)
        print(v.name, value)


if __name__ == "__main__":
    algo = "DQN"
    model_dir = os.path.join("model_export_dir")
    ckpt_dir = os.path.join(ray.utils.get_user_temp_dir(), "ckpt_export_dir")
    prefix = "model.ckpt"
    num_steps = 3
    train_and_export(algo, num_steps, model_dir, ckpt_dir, prefix)
    restore_saved_model(model_dir)
    # restore_checkpoint(ckpt_dir, prefix)

Note that the only change I made was replacing get_trainer_class() with get_agent_class(), since get_trainer_class() does not seem to work. The model is exported correctly.

To test the model at inference time I used saved_model_cli, specifically the following command:
saved_model_cli run --dir model_export_dir --tag_set serve --signature_def serving_default --input_exprs "observations=np.array([[ 0.1, 0.2, 0.3, 0.4]])"

and got the following error message:

Traceback (most recent call last):
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\client\session.py", line 1375, in _do_call
    return fn(*args)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\client\session.py", line 1360, in _run_fn
    target_list, run_metadata)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\client\session.py", line 1453, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'default_policy/timestep' with dtype int64
         [[{{node default_policy/timestep}}]]

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "C:\Programm Files\Python36\python.exe\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Programm Files\Python36\python.exe\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\lgravias\PycharmProjects\RL_lib_tester\venv\Scripts\saved_model_cli.exe\__main__.py", line 7, in <module>
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\tools\saved_model_cli.py", line 1192, in main
    args.func(args)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\tools\saved_model_cli.py", line 752, in run
    init_tpu=args.init_tpu, tf_debug=args.tf_debug)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\tools\saved_model_cli.py", line 450, in run_saved_model_with_feed_dict
    outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\client\session.py", line 968, in run
    run_metadata_ptr)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\client\session.py", line 1191, in _run
    feed_dict_tensor, options, run_metadata)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\client\session.py", line 1369, in _do_run
    run_metadata)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\client\session.py", line 1394, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'default_policy/timestep' with dtype int64
         [[node default_policy/timestep (defined at C:\Users\lgravias\PycharmProjects\RL_lib_tester\venv\Scripts\saved_model_cli.exe\__main__.py:7) ]]

Original stack trace for 'default_policy/timestep':
  File "C:\Programm Files\Python36\python.exe\lib\runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "C:\Programm Files\Python36\python.exe\lib\runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "C:\Users\lgravias\PycharmProjects\RL_lib_tester\venv\Scripts\saved_model_cli.exe\__main__.py", line 7, in <module>
    sys.exit(main())
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\tools\saved_model_cli.py", line 1192, in main
    args.func(args)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\tools\saved_model_cli.py", line 752, in run
    init_tpu=args.init_tpu, tf_debug=args.tf_debug)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\tools\saved_model_cli.py", line 445, in run_saved_model_with_feed_dict
    loader.load(sess, tag_set.split(','), saved_model_dir)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\util\deprecation.py", line 340, in new_func
    return func(*args, **kwargs)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\saved_model\loader_impl.py", line 300, in load
    return loader.load(sess, tags, import_scope, **saver_kwargs)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\saved_model\loader_impl.py", line 453, in load
    **saver_kwargs)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\saved_model\loader_impl.py", line 383, in load_graph
    meta_graph_def, import_scope=import_scope, **saver_kwargs)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\training\saver.py", line 1485, in _import_meta_graph_with_return_elements
    **kwargs))
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\meta_graph.py", line 804, in import_scoped_meta_graph_with_return_elements
    return_elements=return_elements)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\util\deprecation.py", line 538, in new_func
    return func(*args, **kwargs)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\importer.py", line 405, in import_graph_def
    producer_op_list=producer_op_list)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\importer.py", line 513, in _import_graph_def_internal
    _ProcessNewOps(graph)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\importer.py", line 243, in _ProcessNewOps
    for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 3680, in _add_new_tf_operations
    for c_op in c_api_util.new_tf_operations(self)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 3680, in <listcomp>
    for c_op in c_api_util.new_tf_operations(self)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 3561, in _create_op_from_tf_operation
    ret = Operation(c_op, self)
  File "c:\users\lgravias\pycharmprojects\rl_lib_tester\venv\lib\site-packages\tensorflow\python\framework\ops.py", line 1990, in __init__
    self._traceback = tf_stack.extract_stack()

From this one can see that ‘default_policy/timestep’ must be fed a value. Note that I cannot feed a value for timestep because it is not a model input.
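
The CLI only accepts keys that are listed in the signature. From Python, a TF1 session also accepts feeds addressed by tensor name, so the unpatched export can possibly still be run that way. The sketch below assumes the placeholder's tensor name is `default_policy/timestep:0` (derived from the op name in the error message) and that the signature key is `serving_default`; the helpers `build_feed_dict` and `run_exported_policy` are hypothetical names, not part of RLlib or TensorFlow.

```python
from typing import Any, Dict, Mapping, Optional


def build_feed_dict(signature_inputs: Mapping[str, Any],
                    values: Mapping[str, Any],
                    extra_feeds: Optional[Mapping[str, Any]] = None
                    ) -> Dict[str, Any]:
    """Translate signature keys into tensor names and merge by-name feeds.

    `signature_inputs` maps a key such as "observations" to an object with
    a `.name` attribute (a TensorInfo in a real SavedModel signature).
    """
    unknown = set(values) - set(signature_inputs)
    if unknown:
        raise KeyError("not in the signature: %s" % sorted(unknown))
    feed = {signature_inputs[key].name: value
            for key, value in values.items()}
    # Feeds that bypass the signature, keyed by raw tensor name.
    feed.update(extra_feeds or {})
    return feed


def run_exported_policy(export_dir, values, extra_feeds=None):
    """Load a SavedModel into a TF1 session and run its serving signature."""
    # Imported lazily so that build_feed_dict has no TensorFlow dependency.
    import tensorflow as tf

    tf1 = tf.compat.v1
    with tf1.Session(graph=tf1.Graph()) as sess:
        meta_graph_def = tf1.saved_model.load(
            sess, [tf1.saved_model.tag_constants.SERVING], export_dir)
        signature = meta_graph_def.signature_def["serving_default"]
        fetches = {key: info.name for key, info in signature.outputs.items()}
        feed = build_feed_dict(signature.inputs, values, extra_feeds)
        return sess.run(fetches, feed_dict=feed)


# Example call (needs TensorFlow and the exported model directory):
# outputs = run_exported_policy(
#     "model_export_dir",
#     {"observations": [[0.1, 0.2, 0.3, 0.4]]},
#     extra_feeds={"default_policy/timestep:0": 1},  # assumed tensor name
# )
```

This does not fix the exported signature; it only sidesteps it for callers that control the session themselves.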

Now, changing the previous script to this:

#!/usr/bin/env python

import os
import ray
import numpy as np

from ray.rllib.agents.registry import get_agent_class
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

ray.init(num_cpus=10)


def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    cls = get_agent_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    for _ in range(num_steps):
        alg.train()

    # Export tensorflow checkpoint for fine-tuning
    # alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)
    policy = alg.get_policy()
    with policy._sess.graph.as_default():
        tf1.global_variables_initializer()
        builder = tf1.saved_model.builder.SavedModelBuilder(model_dir)
        signature_def_map = build_signature_def(policy)
        builder.add_meta_graph_and_variables(
            policy._sess, [tf1.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
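            # NOTE: FileWriter.add_graph() returns None, so this passes
            # saver=None (the default) and only writes the graph for TensorBoard.
            saver=tf1.summary.FileWriter(model_dir).add_graph(graph=policy._sess.graph),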
            )
        builder.save()



def build_signature_def(self):
    """Build signature def map for tensorflow SavedModelBuilder.
    """
    # build input signatures
    input_signature = self._extra_input_signature_def()
    input_signature["observations"] = \
        tf1.saved_model.utils.build_tensor_info(self._obs_input)

    if self._seq_lens is not None:
        input_signature["seq_lens"] = \
            tf1.saved_model.utils.build_tensor_info(self._seq_lens)

    ### THIS IS WHAT I ADDED ###
    if self._timestep is not None:
        input_signature["timestep"] = \
            tf1.saved_model.utils.build_tensor_info(self._timestep)
    if self._prev_action_input is not None:
        input_signature["prev_action"] = \
            tf1.saved_model.utils.build_tensor_info(
                self._prev_action_input)
    if self._prev_reward_input is not None:
        input_signature["prev_reward"] = \
            tf1.saved_model.utils.build_tensor_info(
                self._prev_reward_input)
    input_signature["is_training"] = \
        tf1.saved_model.utils.build_tensor_info(self._is_training)

    for state_input in self._state_inputs:
        input_signature[state_input.name] = \
            tf1.saved_model.utils.build_tensor_info(state_input)

    # build output signatures
    output_signature = self._extra_output_signature_def()
    for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
        output_signature["actions_{}".format(i)] = \
            tf1.saved_model.utils.build_tensor_info(a)

    for state_output in self._state_outputs:
        output_signature[state_output.name] = \
            tf1.saved_model.utils.build_tensor_info(state_output)
    signature_def = (
        tf1.saved_model.signature_def_utils.build_signature_def(
            input_signature, output_signature,
            tf1.saved_model.signature_constants.PREDICT_METHOD_NAME))
    signature_def_key = (tf1.saved_model.signature_constants.
                         DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    signature_def_map = {signature_def_key: signature_def}
    return signature_def_map


def restore_saved_model(export_dir):
    signature_key = \
        tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    g = tf1.Graph()
    with g.as_default():
        with tf1.Session(graph=g) as sess:
            meta_graph_def = \
                tf1.saved_model.load(sess,
                                     [tf1.saved_model.tag_constants.SERVING],
                                     export_dir)
            print("Model restored!")
            print("Signature Def Information:")
            print(meta_graph_def.signature_def[signature_key])
            print("You can inspect the model using TensorFlow SavedModel CLI.")
            print("https://www.tensorflow.org/guide/saved_model")
            return meta_graph_def


def restore_checkpoint(export_dir, prefix):
    sess = tf1.Session()
    meta_file = "%s.meta" % prefix
    saver = tf1.train.import_meta_graph(os.path.join(export_dir, meta_file))
    saver.restore(sess, os.path.join(export_dir, prefix))
    print("Checkpoint restored!")
    print("Variables Information:")
    for v in tf1.trainable_variables():
        value = sess.run(v)
        print(v.name, value)


if __name__ == "__main__":
    algo = "DQN"
    model_dir = "model_export_dir"
    ckpt_dir = os.path.join(ray.utils.get_user_temp_dir(), "ckpt_export_dir")
    prefix = "model.ckpt"
    num_steps = 3
    train_and_export(algo, num_steps, model_dir, ckpt_dir, prefix)
    exported_model = restore_saved_model(model_dir)
    # restore_checkpoint(ckpt_dir, prefix)


With this change, i.e. a corrected _build_signature_def() that lists timestep as a model input, the model is once again exported with no issues.

When I now use the saved_model_cli as follows:
saved_model_cli run --dir model_export_dir --tag_set serve --signature_def serving_default --input_exprs "observations=np.array([[ 0.1, 0.2, 0.3, 0.4]]);timestep=1"

I get the expected result:

Result for output key action_dist_inputs:
[[4.2693458 5.365109 ]]
Result for output key action_logp:
[0.]
Result for output key action_prob:
[1.]
Result for output key actions_0:
[0]
Result for output key q_values:
[[4.2693458 5.365109 ]]

I hope this code can help reproduce the problem. Let me know if further information is needed.
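
For anyone who wants to script the saved_model_cli call above, here is a small sketch that assembles the same arguments from a dict. The helper names `build_cli_args` and `run_cli` are made up for this example; the flags and the semicolon-separated `--input_exprs` format are the ones used in the command above.

```python
import subprocess
from typing import List, Mapping


def build_cli_args(model_dir: str, input_exprs: Mapping[str, str]) -> List[str]:
    """Build the argument list for `saved_model_cli run`."""
    # Multiple inputs are joined with ';' as in the command shown above.
    exprs = ";".join("%s=%s" % (key, expr)
                     for key, expr in input_exprs.items())
    return ["saved_model_cli", "run",
            "--dir", model_dir,
            "--tag_set", "serve",
            "--signature_def", "serving_default",
            "--input_exprs", exprs]


def run_cli(args: List[str]) -> None:
    """Run the CLI; requires TensorFlow's saved_model_cli on PATH."""
    subprocess.run(args, check=True)


args = build_cli_args("model_export_dir", {
    "observations": "np.array([[0.1, 0.2, 0.3, 0.4]])",
    "timestep": "1",
})
print(" ".join(args))
# run_cli(args)  # uncomment once the model has been exported
```

Passing the arguments as a list avoids shell quoting problems with the brackets and semicolon in the expression string.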