Extract and display policy

Is there an easy way to extract the trained policy within ray.tune after each episode and display it on a grid (e.g. the optimal action for different state values), for example from within the render function?

Similar question on extracting history of actions/policies over an episode.

You could use custom callbacks to get the policy after each episode. Take a look at this example script, which implements an on_episode_end() callback:

ray/rllib/examples/custom_metrics_and_callbacks.py.

You can get to the policy object in that method by doing policy = worker.policy_map["default_policy"]. Then you could evaluate it right there on some task?
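
For example, something along these lines (just a rough sketch; the 1-d observation, the grid bounds, the "default_policy" id and the import path are placeholders and depend on your env and RLlib version):

import numpy as np
from ray.rllib.agents.callbacks import DefaultCallbacks

class PolicyGridCallback(DefaultCallbacks):
    def on_episode_end(self, *, worker, base_env, policies, episode,
                       env_index, **kwargs):
        policy = worker.policy_map["default_policy"]
        # Evaluate the current (greedy) policy on a coarse grid of states.
        for state in np.linspace(0.0, 10.0, 11):
            obs = np.array([state], dtype=np.float32)  # assumed 1-d observation
            action, _, _ = policy.compute_single_action(obs, explore=False)
            print("state={:.1f} -> greedy action={}".format(state, action))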


Hi Sven,
For some reason, using episode.last_observation_for(AGENT_ID) inside the on_episode_step() method of the custom callback class doesn't work. It only ever returns the reset observations (i.e. the initial observations of each episode), not the stepped-forward ones, unless I made some other mistake. I followed the custom metrics and callbacks example class very closely. I checked that the observation is being updated inside my MultiAgentEnv and environments when run under tune.run, so it must be some issue or design choice with episode.last_observation_for(). I'm therefore not sure how to extract observations for easy display.

Is it possible that

For the policy this should work. I am trying to extract the action probabilities via .logp and then plot them on a state grid. Unfortunately, I'm not sure how to get something like 95% intervals without manually sampling.
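
What I have so far for the grid evaluation looks roughly like this (the 7-dimensional observation layout and the fixed values for the non-asset entries are specific to my env, and I'm not certain the fetch keys below are always present):

import numpy as np

def policy_on_grid(policy, n_points=1000):
    # Vary only the first observation dimension (assets); hold the other six
    # entries fixed at arbitrary values -- this layout is specific to my env.
    low = policy.observation_space.low[0]
    high = policy.observation_space.high[0]
    obs_grid = np.linspace([low, 1, 1, 1, 5, 1, 1],
                           [high, 1, 1, 1, 5, 1, 1], n_points).astype(np.float32)
    actions, _, fetches = policy.compute_actions(obs_grid, explore=False)
    # "action_dist_inputs" / "action_logp" seem to be in the default extra
    # fetches (logits, or mean/log-std for a Gaussian), but I haven't verified
    # this for every policy type.
    dist_inputs = fetches["action_dist_inputs"]
    logp = fetches["action_logp"]
    return obs_grid[:, 0], actions, dist_inputs, logp

If the action distribution is a diagonal Gaussian, I assume the mean and log-std could be read off action_dist_inputs directly, which would give analytic 95% bands instead of sampling, but I haven't verified that.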

@sven1977 sorry, I missed a sentence. Do you know whether, in the example class, episode.last_observation_for() called at the end of the episode correctly returns the observation for each period, rather than always returning the reset observations?
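
In the meantime, the workaround I'm considering is to read the state directly from the underlying env objects inside on_episode_step(), roughly like this (get_unwrapped() and the current_obs attribute are just what I would expect to use, not something I've confirmed):

def on_episode_step(self, *, worker, base_env, episode, env_index, **kwargs):
    # Grab the underlying MultiAgentEnv for this vectorized slot and read the
    # observation it stores itself, instead of episode.last_observation_for().
    env = base_env.get_unwrapped()[env_index]
    # `current_obs` is a hypothetical attribute my env keeps updated in step().
    episode.user_data["assets"].append(env.current_obs[str(0)][0])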

This is my current callback class. The problem arises even in on_episode_step, despite my observations being updated in the step function of the actual MultiAgentEnv class.

from typing import Dict

import numpy as np

from ray.rllib.agents.callbacks import DefaultCallbacks
from ray.rllib.env import BaseEnv
from ray.rllib.evaluation import MultiAgentEpisode, RolloutWorker
from ray.rllib.policy import Policy
from ray.rllib.policy.sample_batch import SampleBatch

# AGENT_NUM (the number of agents in the env) is defined elsewhere in my script.


class MyCallBack(DefaultCallbacks):
    """
    Modified callback that logs consumption/savings metrics and will eventually
    plot the policy function when run.
    """

    def on_episode_start(self, *, worker: RolloutWorker, base_env: BaseEnv,
                         policies: Dict[str, Policy],
                         episode: MultiAgentEpisode, env_index: int, **kwargs):
        # Make sure this episode has just been started (only initial obs
        # logged so far).
        assert episode.length == 0, \
            "ERROR: `on_episode_start()` callback should be called right " \
            "after env reset!"
        print("episode {} (env-idx={}) started.".format(
            episode.episode_id, env_index))
        episode.user_data["consumption"] = []
        episode.user_data["savings"] = []
        episode.user_data["assets"] = []
        episode.user_data["net_savings"]= []
        episode.hist_data["consumption"] = []
        episode.hist_data["savings"] = []
        episode.hist_data["net_savings"] = []
        episode.hist_data["assets"] = []
        episode.user_data["all_assets"] = []
        episode.hist_data["all_assets"] = []


    def on_episode_step(self, *, worker: RolloutWorker, base_env: BaseEnv,
                        episode: MultiAgentEpisode, env_index: int, **kwargs):
        # Make sure this episode is ongoing.
        assert episode.length > 0, \
            "ERROR: `on_episode_step()` callback should not be called right " \
            "after env reset!"
        # Observation layout (per the indices used below): index 0 = assets,
        # 1 = price, 2 = income, 5 = interest, 6 = wage; action[0] = consumption.
        consumption = episode.last_action_for(str(0))[0]
        assets = episode.last_observation_for(str(0))[0]
        price = episode.last_observation_for(str(0))[1]
        income = episode.last_observation_for(str(0))[2]
        wage = episode.last_observation_for(str(0))[6]
        interest = episode.last_observation_for(str(0))[5]
        net_savings: float = interest * assets + income - consumption
        savings: float = price * assets + income - consumption

        for i in range(AGENT_NUM):
            all_asset_temp = episode.last_observation_for(str(i))[0]
            episode.user_data["all_assets"].append(all_asset_temp)
        episode.user_data["consumption"].append(consumption)
        episode.user_data["assets"].append(assets)
        episode.user_data["savings"].append(savings)
        episode.user_data["net_savings"].append(net_savings)


    def on_episode_end(self, *, worker: RolloutWorker, base_env: BaseEnv,
                       policies: Dict[str, Policy], episode: MultiAgentEpisode,
                       env_index: int, **kwargs):
        consumption_mean: float = np.mean(episode.user_data["consumption"])
        net_savings_mean: float = np.mean(episode.user_data["net_savings"])
        savings_mean: float = np.mean(episode.user_data["savings"])
        assets_mean: float = np.mean(episode.user_data["all_assets"])
        savings_var: float = np.var(episode.user_data["savings"])
        net_savings_var: float = np.var(episode.user_data["net_savings"])
        consumption_var: float = np.var(episode.user_data["consumption"])
        print("episode {} (env-idx={}) ended with length {} and consumption mean and variance:  {}, {}"
        " and savings mean and variance {}, {}".format(episode.episode_id, env_index, episode.length, consumption_mean, consumption_var, savings_mean, savings_var))
        # Graphs of mean and var over time
        episode.custom_metrics["consumption_mean"] = consumption_mean
        episode.custom_metrics["consumption_var"] = consumption_var
        episode.custom_metrics["net_savings_mean"] = net_savings_mean
        episode.custom_metrics["net_savings_var"] = net_savings_var
        episode.custom_metrics["savings_mean"] = savings_mean
        episode.custom_metrics["savings_var"] = savings_var
        episode.custom_metrics["assets_mean"] = assets_mean
        
        episode.hist_data["all_assets"] = episode.user_data["all_assets"]
        episode.hist_data["net_savings"] = episode.user_data["net_savings"]
        episode.hist_data["assets"] = episode.user_data["assets"]
        episode.hist_data["savings"] = episode.user_data["savings"]
        episode.hist_data["consumption"] = episode.user_data["consumption"]

        # Graphs of Hist over time.
        episode.custom_metrics["consumption_hist"] = episode.hist_data["consumption"]
        episode.custom_metrics["assets_hist"] = episode.hist_data["assets"]
        episode.custom_metrics["savings_hist"] = episode.hist_data["savings"]
        episode.custom_metrics["net_savings_hist"] = episode.hist_data["net_savings"]

    def on_sample_end(self, *, worker: RolloutWorker, samples: SampleBatch,
                      **kwargs):
        print("returned sample batch of size {}".format(samples.count))

    def on_train_result(self, *, trainer, result: dict, **kwargs):
        print("trainer.train() result: {} -> {} episodes".format(
            trainer, result["episodes_this_iter"]))
        # you can mutate the result dict to add new fields to return
        result["callback_ok"] = True


    def on_learn_on_batch(self, *, policy: Policy, train_batch: SampleBatch,
                          result: dict, **kwargs) -> None:
        # Right before learning, we want to display current policy.

        #logits =policy.model.from_batch({"obs": np.linspace([policy.observation_space.low[0],1,1,1,5,1,1],[policy.observation_space.high[0],1,1,1,5,1,1],10000)})
        #distributions 
        #prob_vec = []
        #state vec
        #obs_vec =np.linspace([policy.observation_space.low[0],1,1,1,5,1,1],[policy.observation_space.high[0],1,1,1,5,1,1],10000) 
        #for i in range(0,1000):
        #    logits =policy.model.from_batch({"obs": obs_vec[i]})
        #    probs = tf.nn.softmax(policy.dist_class(logits,policy.model))
        #    prob_vec.append(probs)
            
        result["sum_actions_in_train_batch"] = np.sum(train_batch["actions"])
        print("policy.learn_on_batch() result: {} -> sum actions: {}".format(
            policy, result["sum_actions_in_train_batch"]))

    def on_postprocess_trajectory(
            self, *, worker: RolloutWorker, episode: MultiAgentEpisode,
            agent_id: str, policy_id: str, policies: Dict[str, Policy],
            postprocessed_batch: SampleBatch,
            original_batches: Dict[str, SampleBatch], **kwargs):
        print("postprocessed {} steps".format(postprocessed_batch.count))
        if "num_batches" not in episode.custom_metrics:
            episode.custom_metrics["num_batches"] = 0
        episode.custom_metrics["num_batches"] += 1