Saving off RLlib policy activations from a checkpoint rollout

I’ve trained various RL agents with RLlib (torch backend) on custom environments, using a copy of rllib/train.py in which I registered my custom gym environment. I also have a copy of rllib/rollout.py for visualizing my trained agents on these custom environments. What I want to do next is analyze the trained agents’ activation patterns during a rollout.

I’ve identified one method for saving off the activations of a torch policy, shown in this github gist. I’ve tried integrating that code into my copy of rollout.py by calling agent.get_policy() to get the policy and then registering a forward hook on policy.model. However, when I use this approach, no activations get saved. My guess is that agent.get_policy() is not returning the same policy that actually gets called during the rollout.
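For reference, the pattern from the gist looks roughly like this (a minimal, self-contained sketch on a plain torch module; the model and layer names are illustrative, not from my actual policy):

import collections
from functools import partial

import torch
import torch.nn as nn

# toy stand-in for a policy network
model = nn.Sequential(collections.OrderedDict([
    ("fc1", nn.Linear(4, 16)),
    ("relu", nn.ReLU()),
    ("fc2", nn.Linear(16, 2)),
]))

activations = collections.defaultdict(list)

def save_activation(name, mod, inp, out):
    # detach so the saved tensors don't keep the autograd graph alive
    activations[name].append(out.detach().cpu())

# register_forward_hook calls hook(module, input, output); partial binds the name
for name, mod in model.named_modules():
    if name in ("fc1", "fc2"):
        mod.register_forward_hook(partial(save_activation, name))

model(torch.randn(1, 4))
print({k: [t.shape for t in v] for k, v in activations.items()})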

Is there some way to access the policy that actually gets used during the rollout, other than agent.get_policy(), that would let me save off activations like this? I know I could alternatively save the policy model and write my own script that loads it and my env, but that basically means writing my own version of rollout.py. It would be nice to do this from my custom copy of rllib/rollout.py instead.

Actually, you are right. We recently changed rollout.py to use the Trainer’s evaluation workers (instead of the Policy in the Trainer’s local worker) in order to allow parallelization in rollouts (previously, one could only roll out one episode at a time).

You can, however, set num_workers to 0 in the config you pass to rollout.py on the command line (--config '{"num_workers": 0}'), then do this in your (custom) code to get the policy used for evaluation (rolling out):

policy_used_for_rollout = agent.evaluation_workers.local_worker().get_policy()
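
For example, combined with a forward hook (a sketch, assuming the default single-policy setup; agent here is the restored Trainer inside rollout.py):

# sketch inside rollout.py, after the Trainer `agent` is created and the
# checkpoint is restored; assumes --config '{"num_workers": 0}' was passed
import collections
from functools import partial

policy = agent.evaluation_workers.local_worker().get_policy()

activations = collections.defaultdict(list)

def save_activation(name, mod, inp, out):
    activations[name].append(out.cpu())

for name, mod in policy.model.named_modules():
    if name:  # skip the unnamed root module
        mod.register_forward_hook(partial(save_activation, name))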

Thanks for the reply. I figured this out on my own yesterday. Here are a few code snippets showing how I implemented it, for others who want the same functionality:

# in my copy of rollout.py
import collections
from functools import partial
...
    if args.save_policy_activations:
        config["num_workers"] = 0

    ray.init()

    # Create the Trainer from config.
    cls = get_trainable_cls(args.run)
    agent = cls(env=args.env, config=config)
...
...
    # set up introspection into policy activations
    if args.save_policy_activations:
        local_worker = agent.workers.local_worker()
        policy_activations = {}
        for policy_name, policy in local_worker.policy_map.items():
            print(f"found policy with name '{policy_name}'")
            print("Saving activations for the following layers:")
            policy_activations[policy_name] = collections.defaultdict(list)
            def save_activation(name, mod, inp, out):
                policy_activations[policy_name][name].append(out.cpu())
            for name, mod in policy.model.named_modules():
                if name in args.save_policy_activations:
                    mod.register_forward_hook(partial(save_activation, name))
                    print(f"\t{name} -- {mod.__class__}")

Great, thanks a lot @josephcarmack for sharing this solution!

There was a small bug in my implementation from the previous post that caused it not to work with multi-agent policies: the hook closure captured the loop variable policy_name, so every hook ended up writing into the last policy’s dict. Here is the update:

# in my rollout.py
import collections
from functools import partial
...
...
    # Create the Trainer from config.
    cls = get_trainable_cls(args.run)
    agent = cls(env=args.env, config=config)
...
...
    # set up introspection into policy activations
    if args.save_policy_activations:
        # define a hook function generator, one hook per policy
        def getForwardHookFunc(activations_dict, policy_name):
            """Returns a forward hook function that saves activations for the given policy."""
            def hook(name, mod, inp, out):
                activations_dict[policy_name][name].append(out.cpu())
            return hook

        local_worker = agent.workers.local_worker()
        policy_activations = {}
        # register a forward hook function for each policy
        # (there will be multiple policies in multi-agent/hierarchical envs)
        for policy_name, policy in local_worker.policy_map.items():
            print(f"found policy with name '{policy_name}'")
            print("Saving activations for the following layers:")
            policy_activations[policy_name] = collections.defaultdict(list)
            # create a forward hook function for each policy
            hook_func = getForwardHookFunc(policy_activations, policy_name)
            for name, mod in policy.model.named_modules():
                if name in args.save_policy_activations:
                    mod.register_forward_hook(partial(hook_func, name))
                    print(f"\t{name} -- {mod.__class__}")