[rllib] How to infer Dict type obs with exported model?


RLlib uses StructuredTensor to deal with complex obs spaces.

What’s the problem?

I used trainer.get_policy().export_model() to export an RLlib checkpoint into a TF SavedModel (.pb files). The export process works well, but when I want to use the exported model for inference, I find that the observations input requires a Tensor, while my input_dict is a Dict with an action mask. So, should I convert my dict inputs into a StructuredTensor?


The input_dict is a Dict like this:

{"avail_action": np.array([0.0] * 59), "action_mask": np.array([0.0] * 59), "state": np.zeros(shape=(4, 27, 8))}

import tensorflow as tf

# exported_model_path is the directory written by export_model().
predict_fn = tf.saved_model.load(exported_model_path)
infer = predict_fn.signatures["serving_default"]
outputs = infer(
    observations=tf.constant([input_dict["state"].tolist()], dtype=tf.float32),
    prev_reward=tf.constant(0.0),
    is_training=tf.constant(False),
    seq_lens=tf.constant(0, dtype=tf.int32),
    prev_action=tf.constant(1, dtype=tf.int64),
)

The required observations input looks like this:

As we can see, observations is the flattened input_dict, but I can't find a way to convert a Dict into a StructuredTensor.
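For what it's worth, a minimal sketch of the flattening for this particular input_dict, assuming the obs space is a gym Dict space (which stores its keys in sorted order, so the flattened vector is each entry flattened and concatenated in alphabetical key order):

```python
import numpy as np

# The same Dict observation as above.
input_dict = {
    "avail_action": np.array([0.0] * 59),
    "action_mask": np.array([0.0] * 59),
    "state": np.zeros(shape=(4, 27, 8)),
}

# Flatten each entry and concatenate in sorted key order:
# action_mask, avail_action, state.
flat_obs = np.concatenate(
    [np.asarray(input_dict[k], dtype=np.float32).ravel()
     for k in sorted(input_dict)]
)
print(flat_obs.shape)  # (982,) = 59 + 59 + 4*27*8
```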

Any suggestion would be helpful!

Great question! For the inference, are you using Trainer.compute_actions, Policy.compute_actions, or Policy.compute_single_action, or Trainer.compute_action? :confused: Yes, we need to clean this up a little and make it more intuitive. The Trainer methods will actually accept an observation from the environment, then preprocess it (e.g. flatten the dict). The Policy methods require an already preprocessed input.

Yes! If I use Trainer.compute_action(), the trainer can deal with the Dict type input (except for one error, for which I submitted an issue: Trainer.compute_action Error with Dict type observation inputs).

But when I export the TF model from PPO.Trainer (getting the .pb file), the Policy methods require the preprocessed input, and I don't know how to convert a Dict input into a flattened input.

So, can RLlib provide a function that does this preprocessing (flattening the dict)? Or, if you can tell me the details of the flatten operation, maybe I can implement it myself.

Thank you guys for building such an excellent project! :grinning:

Hey @hybug.
Yes, the Policy object always requires the already flattened obs as inputs into its compute_action methods.
The Trainer will do the preprocessing itself, so you could simply use Trainer.compute_action or Trainer.compute_actions and pass in the Dict observation.
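If you do need to preprocess manually at serving time (e.g. without a Trainer around), RLlib exposes preprocessors in ray.rllib.models.preprocessors (get_preprocessor returns the right class for a space). The flatten logic it applies to Dict observations can also be sketched without importing ray, assuming keys are visited in sorted order as in gym's Dict space; flatten_obs here is a hypothetical helper name, not a library function:

```python
import numpy as np

def flatten_obs(obs):
    """Recursively flatten a (possibly nested) dict observation into one
    1-D float32 vector, visiting keys in sorted order. A sketch of the
    concatenation order RLlib's dict preprocessor uses, not the library
    implementation itself."""
    if isinstance(obs, dict):
        return np.concatenate([flatten_obs(obs[k]) for k in sorted(obs)])
    return np.asarray(obs, dtype=np.float32).ravel()

# Example: the nested pieces collapse into one 982-element vector.
obs = {
    "action_mask": np.zeros(59),
    "avail_action": np.zeros(59),
    "state": np.zeros((4, 27, 8)),
}
print(flatten_obs(obs).shape)  # (982,)
```

The resulting vector can then be fed as the observations tensor of the exported SavedModel's serving signature.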

I fixed the issue. The bug was that the spaces for the local worker were not translated into their original forms, so the local worker did not create a proper preprocessor to be used with Trainer.compute_action().
