How to load an observation filter without loading the whole agent?

I am currently training some agents on my local machine, which has limited RAM. For my evaluations I restore just the policy from checkpoints instead of the whole agent, which saves a lot of memory, since restoring the agent restores the whole experiment. Here is the code for my policy class:

from os.path import join

from ray.rllib.models.preprocessors import get_preprocessor
from ray.rllib.policy.policy import Policy as RLLibPolicy
from ray.rllib.utils.spaces import space_utils

from src.rllib_utills.configs import get_checkpoint_configs


class Policy:
    """ Easy to use RlLib policy wrapper
    Also this class allow the user to compute actions on the saved policy without restoring the whole experiment
    """

    def __init__(self, checkpoint_path, obs_space, action_space):
        """
        loads the policy and the pre-processor for the observations
        :param checkpoint_path:
        :param obs_space:
        """
        # Read the training config and restore only the default policy
        # (not the whole Algorithm) from the checkpoint.
        configs = get_checkpoint_configs(checkpoint_path)
        policy_path = join(checkpoint_path, 'policies', 'default_policy')
        self.policy = RLLibPolicy.from_checkpoint(policy_path)
        self.use_lstm = configs["model"]["use_lstm"]
        self.lstm_use_prev_actions = configs["model"]["lstm_use_prev_action"]
        if self.use_lstm:
            self.state = self.policy.get_initial_state()

        # Preprocessors flatten raw space samples into the format the policy expects.
        self.preprocessor = get_preprocessor(obs_space)(obs_space)
        self.action_preprocessor = get_preprocessor(action_space)(action_space)
        self.prev_action = action_space.sample()

    def __call__(self, obs, explore=False):
        """
        Computes the action for a given observed state
        :param obs: the observed state
        :param explore: if turned on, an action is sampled from the policy instead of returning the best action
        :return: the action returned by the policy
        """
        obs = self.preprocessor.transform(obs)

        if self.use_lstm:
            if self.lstm_use_prev_actions:
                action, self.state, _ = self.policy.compute_single_action(obs=obs, state=self.state, explore=explore,
                                                                          prev_action=self.prev_action)
            else:
                action, self.state, _ = self.policy.compute_single_action(obs=obs, state=self.state, explore=explore)

        else:
            action, _, _ = self.policy.compute_single_action(obs=obs, explore=explore)

        # Map the (possibly squashed) action back into the original action space.
        action = space_utils.unsquash_action(action, self.policy.action_space_struct)
        self.prev_action = action
        return action
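
For evaluation I use the wrapper roughly like this (a minimal sketch; it assumes an old-style Gym env with the 4-tuple step API and a checkpoint directory written by algo.save()):

import gym

checkpoint_path = "/path/to/checkpoint"  # directory produced by algo.save()
env = gym.make("CartPole-v1")
policy = Policy(checkpoint_path, env.observation_space, env.action_space)

obs = env.reset()
done = False
while not done:
    action = policy(obs)  # preprocessing + compute_single_action under the hood
    obs, reward, done, info = env.step(action)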

It works well for my experiments, as I can load the policy for evaluation while training without running out of memory. The problem is that I recently tried the “MeanStdFilter” observation filter, and the results I got were much better than without it. So how do I load the observation filter to use with the policy's compute_single_action without loading the whole agent?

We have an ongoing effort to enable using the filters with the raw policy, as you described, as part of the Connectors API (Connectors (Alpha) — Ray 2.3.1).

The MeanStdFilter has already been migrated (ray/mean_std_filter.py at master · ray-project/ray · GitHub), so if you simply pass enable_connectors=True to your algorithm config during training and save your policy checkpoint, the filter should be applied out of the box when you restore your policy.
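
For example, during training something along these lines (a minimal sketch using PPO and the AlgorithmConfig API; observation_filter and enable_connectors are the relevant settings, the environment and algorithm are just placeholders):

from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("CartPole-v1")
    .rollouts(
        observation_filter="MeanStdFilter",  # normalize observations during rollouts
        enable_connectors=True,              # store the filter with the policy as a connector
    )
)
algo = config.build()
algo.train()
checkpoint_path = algo.save("/tmp/my_checkpoint")

At inference time the restored policy then carries its agent connectors (including the filter), and there is a small helper that runs them for you. Roughly like this (again a sketch; I am going from memory on local_policy_inference, so please double-check the exact signature against your Ray version):

import gym
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.policy import local_policy_inference

policy = Policy.from_checkpoint("/tmp/my_checkpoint/policies/default_policy")

env = gym.make("CartPole-v1")
obs = env.reset()
done = False
while not done:
    # Runs the agent connectors (e.g. the MeanStdFilter) before querying the policy.
    outputs = local_policy_inference(policy, "env_0", "agent_0", obs)
    action, _, _ = outputs[0]
    obs, reward, done, info = env.step(action)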

Note: The Connectors API is currently still in alpha, but feel free to respond with any issues you encounter.
