[RLlib] Problem with TFModelV2 loading after having saved one with `TFPolicy.export_model()`

After some debugging and searching, I have found a workaround for the problem. Since I saw a post from @kepricon describing the same issue, I figured I should share my solution here.

The issue stems from the `_build_signature_def()` method of the `TFPolicy` class in `tf_policy.py`. That method does not add the timestep placeholder to `input_signature`, even though the exported model requires it as an input. The patched function is the following:

```python
def _build_signature_def(self):
    """Build signature def map for tensorflow SavedModelBuilder.
    """
    # build input signatures
    input_signature = self._extra_input_signature_def()
    input_signature["observations"] = \
        tf1.saved_model.utils.build_tensor_info(self._obs_input)

    if self._seq_lens is not None:
        input_signature["seq_lens"] = \
            tf1.saved_model.utils.build_tensor_info(self._seq_lens)

    ### THIS IS WHAT I ADDED ###
    if self._timestep is not None:
        input_signature["timestep"] = \
            tf1.saved_model.utils.build_tensor_info(self._timestep)
    ### END OF ADDITION ###
    if self._prev_action_input is not None:
        input_signature["prev_action"] = \
            tf1.saved_model.utils.build_tensor_info(
                self._prev_action_input)
    if self._prev_reward_input is not None:
        input_signature["prev_reward"] = \
            tf1.saved_model.utils.build_tensor_info(
                self._prev_reward_input)
    input_signature["is_training"] = \
        tf1.saved_model.utils.build_tensor_info(self._is_training)

    for state_input in self._state_inputs:
        input_signature[state_input.name] = \
            tf1.saved_model.utils.build_tensor_info(state_input)

    # build output signatures
    output_signature = self._extra_output_signature_def()
    for i, a in enumerate(tf.nest.flatten(self._sampled_action)):
        output_signature["actions_{}".format(i)] = \
            tf1.saved_model.utils.build_tensor_info(a)

    for state_output in self._state_outputs:
        output_signature[state_output.name] = \
            tf1.saved_model.utils.build_tensor_info(state_output)
    signature_def = (
        tf1.saved_model.signature_def_utils.build_signature_def(
            input_signature, output_signature,
            tf1.saved_model.signature_constants.PREDICT_METHOD_NAME))
    signature_def_key = (tf1.saved_model.signature_constants.
                         DEFAULT_SERVING_SIGNATURE_DEF_KEY)
    signature_def_map = {signature_def_key: signature_def}
    return signature_def_map
```
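To check whether an exported model actually declares `timestep`, you can load the SavedModel and list the input keys of its signature. The following is a minimal sketch. It assumes TensorFlow's `tf.compat.v1` API and the `serve` tag / `serving_default` key that `_build_signature_def()` uses. `list_signature_inputs` is a hypothetical helper, not part of RLlib.

```python
import tensorflow as tf

tf1 = tf.compat.v1

# Same constants that _build_signature_def() and the export code rely on.
SERVING_TAG = tf1.saved_model.tag_constants.SERVING  # "serve"
DEFAULT_SIGNATURE_KEY = (
    tf1.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)


def list_signature_inputs(model_dir, signature_key=DEFAULT_SIGNATURE_KEY):
    """Hypothetical helper: sorted input keys of one SavedModel signature."""
    graph = tf1.Graph()
    with tf1.Session(graph=graph) as sess:
        # Restores graph + variables and returns the MetaGraphDef proto.
        meta_graph_def = tf1.saved_model.loader.load(
            sess, [SERVING_TAG], model_dir)
    # Explicit check: indexing a protobuf message map with a missing key
    # would silently create an empty entry instead of failing.
    if signature_key not in meta_graph_def.signature_def:
        raise KeyError(signature_key)
    signature = meta_graph_def.signature_def[signature_key]
    return sorted(signature.inputs.keys())
```

Calling `list_signature_inputs(model_dir)` on a model exported with the patched function should return a list that includes `"timestep"`. With the unpatched function, that key should be absent.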

For this example to work properly, I changed the `train_and_export()` function as follows, so that it uses the new `_build_signature_def()`:

```python
# Imports as in Ray 1.x (get_agent_class was later renamed in newer releases).
from ray.rllib.agents.registry import get_agent_class
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()


def train_and_export(algo_name, num_steps, model_dir, ckpt_dir, prefix):
    cls = get_agent_class(algo_name)
    alg = cls(config={}, env="CartPole-v0")
    for _ in range(num_steps):
        alg.train()

    # Export tensorflow checkpoint for fine-tuning
    # alg.export_policy_checkpoint(ckpt_dir, filename_prefix=prefix)
    policy = alg.get_policy()
    with policy._sess.graph.as_default():
        # No global_variables_initializer() here: merely creating it is a
        # no-op, and running it would reset the trained weights before saving.
        builder = tf1.saved_model.builder.SavedModelBuilder(model_dir)
        # Call the patched _build_signature_def() with the policy as `self`,
        # so that "timestep" becomes part of the input signature.
        signature_def_map = _build_signature_def(policy)
        builder.add_meta_graph_and_variables(
            policy._sess, [tf1.saved_model.tag_constants.SERVING],
            signature_def_map=signature_def_map,
            strip_default_attrs=False)
        builder.save()
```
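Once the model is exported, you can run it outside RLlib by feeding each signature input under its key, including `timestep`. The following is a minimal sketch under the same `tf.compat.v1` assumptions. `run_exported_policy` is a hypothetical helper, and `"actions_0"`-style output keys come from the patched function above.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def run_exported_policy(model_dir, feeds, signature_key="serving_default"):
    """Hypothetical helper: one forward pass through a SavedModel signature.

    `feeds` maps signature input keys (e.g. "observations", "timestep") to
    values. Returns a dict mapping output keys to numpy arrays.
    """
    graph = tf1.Graph()
    with tf1.Session(graph=graph) as sess:
        meta_graph_def = tf1.saved_model.loader.load(
            sess, [tf1.saved_model.tag_constants.SERVING], model_dir)
        signature = meta_graph_def.signature_def[signature_key]
        # Reject keys the signature does not declare (e.g. "timestep" on a
        # model exported without the patch).
        missing = set(feeds) - set(signature.inputs.keys())
        if missing:
            raise KeyError("not in signature: {}".format(sorted(missing)))
        # Translate signature keys into the underlying graph tensors.
        feed_dict = {
            graph.get_tensor_by_name(signature.inputs[key].name): value
            for key, value in feeds.items()
        }
        fetches = {
            key: graph.get_tensor_by_name(info.name)
            for key, info in signature.outputs.items()
        }
        return sess.run(fetches, feed_dict=feed_dict)
```

For the CartPole export above, `feeds` would need at least `"observations"` (presumably a batch of shape `[N, 4]`) and `"timestep"`. Which other keys must be fed depends on how the policy declares them, so it is worth listing the signature inputs first.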

With this change, `timestep` shows up as an input of the exported model, and a value can be fed to it. I am not sure whether the timestep is meant to be an actual model input or whether it is simply missing a default value (I do not have much experience with TensorFlow). In any case, I will probably open an issue on GitHub as well.
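On the default-value question: TensorFlow's `tf.compat.v1.placeholder_with_default` creates an input that can be fed but falls back to a default value when it is not. The following sketch only illustrates that mechanism on a toy graph. It is not RLlib code, and whether RLlib should declare its timestep this way is an open assumption.

```python
import tensorflow as tf

tf1 = tf.compat.v1


def build_toy_graph():
    """Toy graph with one required input and one optional (defaulted) input."""
    graph = tf1.Graph()
    with graph.as_default():
        obs = tf1.placeholder(tf.float32, [None, 4], name="observations")
        # Optional input: evaluates to 0 unless a value is fed explicitly.
        timestep = tf1.placeholder_with_default(
            tf.constant(0, dtype=tf.int64), shape=(), name="timestep")
        out = tf.reduce_sum(obs, axis=1) + tf.cast(timestep, tf.float32)
    return graph, obs, timestep, out


graph, obs, timestep, out = build_toy_graph()
with tf1.Session(graph=graph) as sess:
    batch = [[1.0, 2.0, 3.0, 4.0]]
    # timestep omitted: the default value 0 is used
    print(sess.run(out, feed_dict={obs: batch}))  # [10.]
    # timestep fed explicitly, as an exported signature input would allow
    print(sess.run(out, feed_dict={obs: batch, timestep: 5}))  # [15.]
```

With a defaulted placeholder, the exported signature could still list `timestep`, but callers that do not care about it would not be forced to feed it.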
