Getting Critic Output for Given Observation


I have an idea that requires me to be able to feed a network with a critic an observation and get the critic’s evaluation. I’ve been searching the web for a way to do this for a while now and am coming up with nothing. Is there a way to do this? (I’m using APPO if that matters.)

Thank you!

Hi automata,

The following code shows how you can ‘manually’ get the value prediction out of your model:

import gym
import ray
from ray import tune
from ray.rllib.agents.ppo import APPOTrainer
from ray.rllib.utils.framework import try_import_tf

tf1, tf, tfv = try_import_tf()

if __name__ == "__main__":

    config = {
        "env": "CartPole-v0",
        # Eager mode, so the value-function output below is a concrete
        # tensor rather than a symbolic graph node.
        "framework": "tfe",
    }

    # Do some training and store the checkpoint.
    results = tune.run(
        APPOTrainer,
        config=config,
        stop={"training_iteration": 1},
        checkpoint_at_end=True,
    )
    best_checkpoint = results.get_best_checkpoint(
        results.trials[0], mode="max")

    # Build a fresh trainer and load the trained weights into it; without
    # the restore the critic would be evaluated with untrained weights.
    new_trainer = APPOTrainer(config=config)
    new_trainer.restore(best_checkpoint)

    policy = new_trainer.get_policy()

    single_env = gym.make("CartPole-v0")
    obs = single_env.reset()

    # value_function() reports the critic output of the most recent forward
    # pass, so push the observation through the policy first.
    policy.compute_actions_from_input_dict({"obs": [obs]})
    value_function_output_for_obs = policy.model.value_function()
    print(value_function_output_for_obs)



It is mostly stitched together from examples.


1 Like

Thank you! This works well in the example; however, in my code (too long to post here) the result of policy.model.value_function() is a Tensor("Reshape:0", shape=(?,), dtype=float32), strangely enough. I can run in “tfe” mode and evaluate that to a numpy array, but that slows things down significantly, so I’d like to get something with a value if I can. I’ll try making a minimal code example that does this; in the meantime, I’m wondering if this is something you might know the reason for offhand. Thank you!

Here is a minimal example:

import gym
import numpy
import os
import ray
import ray.rllib

class Sim(ray.rllib.env.multi_agent_env.MultiAgentEnv):
    """Minimal four-robot multi-agent environment for probing the critic.

    Every episode ends after a single step and always yields the same fixed
    observations with zero reward. Once ``policy`` has been set from outside
    the environment, ``reset()`` runs the observations through it and prints
    the model's value-function output.
    """

    def __init__(self, config):
        """Create the environment.

        Args:
            config: Environment config passed in by RLlib; not used here.
        """
        # Stays None until something outside the env assigns a policy.
        self.policy = None
        self.observation_space = gym.spaces.Box(
            numpy.float32("-inf"),
            numpy.float32("inf"),
            shape=(2,),
            dtype=numpy.float32,
        )
        self.action_space = gym.spaces.MultiDiscrete([3] * 2)

    def _observations(self):
        """Return a fresh dict with the fixed observation of every robot."""
        return {
            "robot0": [0, 0],
            "robot1": [0, 1],
            "robot2": [1, 0],
            "robot3": [1, 1],
        }

    def reset(self):
        """Start an episode and print the critic output when possible.

        Returns:
            Dict mapping each robot id to its observation.
        """
        observations = self._observations()
        # seems on_episode_start() is called after reset() so I don't know
        # how to get the policy on the first pass
        if self.policy:
            # value_function() reports the output of the most recent forward
            # pass, so the batch of observations has to go through the
            # policy before it is queried.
            batch = numpy.asarray(list(observations.values()))
            self.policy.compute_actions_from_input_dict({"obs": batch})
            print(self.policy.model.value_function())
        return observations

    def step(self, action):
        """Advance one step; the episode always terminates immediately.

        Args:
            action: Actions chosen for the robots; ignored.

        Returns:
            Tuple ``(observations, rewards, done, info)`` with fixed
            observations, zero reward per robot, ``{"__all__": True}`` and
            an empty info dict.
        """
        observations = self._observations()
        rewards = dict.fromkeys(observations, 0)
        done = {"__all__": True}
        return observations, rewards, done, {}

class Callbacks(ray.rllib.agents.callbacks.DefaultCallbacks):
    """Callbacks that store the default policy on the episode's environment."""

    def __init__(self, legacy_callbacks_dict=None):
        """Forward ``legacy_callbacks_dict`` unchanged to the base class."""
        super().__init__(legacy_callbacks_dict)

    def on_episode_start(self, *, worker, base_env, policies, episode, **kwargs):
        """Assign ``policies["default_policy"]`` to the env if it has none yet."""
        # Pick the sub-environment this episode runs in.
        sim = base_env.get_unwrapped()[episode.env_id]
        if sim.policy:
            # A policy was already assigned on an earlier episode.
            return
        sim.policy = policies["default_policy"]
# Explicit import so `ray.tune.run` does not depend on `ray.rllib` having
# pulled in the tune subpackage as a side effect.
import ray.tune

# Results are written to an "Output" directory next to this script.
SCRIPT_DIRECTORY = os.path.dirname(os.path.realpath(__file__))

ray.init()
ray.tune.run(
    "APPO",
    config={
        "env": Sim,
        "env_config": {},
        "callbacks": Callbacks,
        "model": {"fcnet_hiddens": [256, 256], "fcnet_activation": "relu"},
        "num_workers": 0,
        "num_envs_per_worker": 1,
        "num_gpus": 0.0,
    },
    local_dir=SCRIPT_DIRECTORY,
    name="Output",
)

This prints out stuff like:

(pid=146255) Tensor("Reshape:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_1:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_2:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_3:0", shape=(?,), dtype=float32)
(pid=146255) Tensor("Reshape_4:0", shape=(?,), dtype=float32)

I think I found the problem: I was using the default “tf” framework. Switching to “tf2” seems to solve the issue.

Hi automata,
Great! tf2 executes eagerly by default. The output that you saw, "Tensor("Reshape:0", shape=(?,), dtype=float32)", is a symbolic tensor. Executing eagerly lets you get rid of these!


1 Like