RLlib Multiagent: Load only one policy from checkpoint & Compatibility of RLlib/Tune Checkpoints

I am working in a multiagent setup with 3 agents, and I want to use pretrained weights for one (and only one!) of them. In my current workflow, I run my experiments using tune.run() and would prefer to keep it that way. Is it possible to restore only the weights of one agent from a checkpoint in a clean way?

One (somewhat hacky) workaround I tried was calling a function before the tune.run() call that does the following:

  1. Initialize an RLlib trainer (“trainer1”).
  2. Load the checkpoint into “trainer1”.
  3. Get the weights for the agent via trainer1.get_weights(pretrain_agent).
  4. Initialize another, randomly initialized trainer (“trainer2”).
  5. Load the pretrained weights into trainer2.
  6. Save trainer2 as a checkpoint → use this checkpoint for tune.run(restore=…).

However, this didn’t work, because I couldn’t use the checkpoint generated by RLlib that way in the tune.run() call (metadata missing). Is it possible to do this? That would solve my main problem with the workflow proposed above.
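Steps 3 to 5 amount to a per-policy dictionary merge: keep the freshly initialized weights for every policy except the pretrained one. A minimal sketch in plain Python, assuming weights are held in a dict keyed by policy ID (the shape RLlib's Trainer.get_weights() returns); the helper name merge_policy_weights and the policy IDs are hypothetical, and plain lists stand in for weight arrays.

```python
from typing import Any, Dict, Set

# Policy ID -> that policy's weights (whatever Policy.get_weights() returns).
WeightsDict = Dict[str, Any]


def merge_policy_weights(fresh: WeightsDict, pretrained: WeightsDict,
                         restore_ids: Set[str]) -> WeightsDict:
    """Hypothetical helper: take the policies in `restore_ids` from
    `pretrained` and every other policy from `fresh`."""
    missing = set(restore_ids) - set(pretrained)
    if missing:
        raise KeyError(f"not in pretrained weights: {sorted(missing)}")
    return {
        pid: pretrained[pid] if pid in restore_ids else w
        for pid, w in fresh.items()
    }


if __name__ == "__main__":
    fresh = {"policy_0": [0.0, 0.0], "policy_1": [0.0, 0.0]}
    pretrained = {"policy_0": [1.5, -0.3], "policy_1": [2.0, 2.0]}
    # Only policy_0 comes from the pretrained set.
    print(merge_policy_weights(fresh, pretrained, {"policy_0"}))
```

With RLlib objects, the merged dict would be passed to trainer2.set_weights(...) before calling trainer2.save().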

Hey @Rafael_Albert, great question. We should add an example to RLlib showing how to do this.
Could you try doing the following and let me know? I’ll also PR this right now.

"""Simple example of how to restore only one of n agents from a trained
multi-agent Trainer using Ray tune.

The trick/workaround is to use an intermediate trainer that loads the
trained checkpoint into all policies and then reverts those policies
that we don't want to restore, then saves a new checkpoint, from which
tune can pick up training.

Control the number of agents and policies via --num-agents and --num-policies.
"""

import argparse
import gym
import os
import random

import ray
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer
from ray.rllib.examples.env.multi_agent import MultiAgentCartPole
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.test_utils import check_learning_achieved

tf1, tf, tfv = try_import_tf()

parser = argparse.ArgumentParser()

parser.add_argument("--num-agents", type=int, default=4)
parser.add_argument("--num-policies", type=int, default=2)
parser.add_argument("--pre-training-iters", type=int, default=5)
parser.add_argument("--stop-iters", type=int, default=200)
parser.add_argument("--stop-reward", type=float, default=150)
parser.add_argument("--stop-timesteps", type=int, default=100000)
parser.add_argument("--num-cpus", type=int, default=0)
parser.add_argument("--as-test", action="store_true")
parser.add_argument(
    "--framework", choices=["tf2", "tf", "tfe", "torch"], default="tf")

if __name__ == "__main__":
    args = parser.parse_args()

    ray.init(num_cpus=args.num_cpus or None)

    # Get obs- and action Spaces.
    single_env = gym.make("CartPole-v0")
    obs_space = single_env.observation_space
    act_space = single_env.action_space

    # Setup PPO with an ensemble of `num_policies` different policies.
    policies = {
        f"policy_{i}": (None, obs_space, act_space, {})
        for i in range(args.num_policies)
    }
    policy_ids = list(policies.keys())

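    def policy_mapping_fn(agent_id, episode=None, **kwargs):
        # Map each agent to a randomly chosen policy. The extra args absorb
        # what newer Ray versions pass (episode, worker).
        pol_id = random.choice(policy_ids)
        return pol_id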

    config = {
        "env": MultiAgentCartPole,
        "env_config": {
            "num_agents": args.num_agents,
        },
        # Use GPUs iff `RLLIB_NUM_GPUS` env var set to > 0.
        "num_gpus": int(os.environ.get("RLLIB_NUM_GPUS", "0")),
        "num_sgd_iter": 10,
        "multiagent": {
            "policies": policies,
            "policy_mapping_fn": policy_mapping_fn,
        },
        "framework": args.framework,
    }

    # Do some training and store the checkpoint.
    results = tune.run(
        "PPO",
        config=config,
        stop={"training_iteration": args.pre_training_iters},
        verbose=1,
        checkpoint_freq=1,
        checkpoint_at_end=True,
    )
    print("Pre-training done.")

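    # metric="training_iteration" with mode="max" selects the latest checkpoint.
    best_checkpoint = results.get_best_checkpoint(
        results.trials[0], metric="training_iteration", mode="max")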
    print(f".. best checkpoint was: {best_checkpoint}")

    # Create a new dummy Trainer to "fix" our checkpoint.
    new_trainer = PPOTrainer(config=config)
    # Get untrained weights for all policies.
    untrained_weights = new_trainer.get_weights()
    # Restore all policies from checkpoint.
    new_trainer.restore(best_checkpoint)
    # Set all weights (except for policy_0's) back to the original,
    # untrained weights.
    new_trainer.set_weights({
        pid: w for pid, w in untrained_weights.items()
        if pid != "policy_0"
    })
    # Create the checkpoint from which tune can pick up the
    # experiment.
    new_checkpoint = new_trainer.save()
    # Release the dummy Trainer's resources before the next tune.run.
    new_trainer.stop()
    print(".. checkpoint to restore from (all policies reset, "
          f"except policy_0): {new_checkpoint}")

    print("Starting new tune.run")

    # Start our actual experiment.
    stop = {
        "episode_reward_mean": args.stop_reward,
        "timesteps_total": args.stop_timesteps,
        "training_iteration": args.stop_iters,
    }

    # Freeze the restored policy_0: only the re-initialized policies are
    # trained from here on (use ["policy_0"] to train only the restored one).
    config["multiagent"]["policies_to_train"] = [
        pid for pid in policy_ids if pid != "policy_0"
    ]

    results = tune.run(
        "PPO",
        stop=stop,
        config=config,
        verbose=1,
        restore=new_checkpoint,
    )

    if args.as_test:
        check_learning_achieved(results, args.stop_reward)
    ray.shutdown()
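Note that the example's policy_mapping_fn picks a policy at random, so policy_0 is not tied to any particular agent. If, as in the original question, one specific agent should always act with the restored policy, the mapping has to be deterministic. A sketch, assuming integer agent IDs 0..num_agents-1 (which, to my knowledge, is what MultiAgentCartPole uses) and the example's policy_<i> naming; make_policy_mapping_fn and pretrained_agent are hypothetical names.

```python
def make_policy_mapping_fn(pretrained_agent, num_policies):
    """Hypothetical factory: pin `pretrained_agent` to policy_0 and spread
    all other agents over policy_1 .. policy_{num_policies - 1}."""
    if num_policies < 2:
        raise ValueError("need policy_0 plus at least one other policy")

    def policy_mapping_fn(agent_id, episode=None, **kwargs):
        # Extra args absorb the (episode, worker) arguments newer Ray versions pass.
        if agent_id == pretrained_agent:
            return "policy_0"
        # Deterministic round-robin over the non-restored policies.
        return f"policy_{1 + agent_id % (num_policies - 1)}"

    return policy_mapping_fn
```

In the script above this would replace the random mapping, e.g. config["multiagent"]["policy_mapping_fn"] = make_policy_mapping_fn(0, args.num_policies), so the agent with ID 0 always acts with the restored policy_0.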

Here is the PR:

Aha, this is what I was after a few days ago!

It’s kinda frustrating that you have to start up a trainer, with all the overhead that entails, just to get the weights out. If I want to get policy weights from a checkpoint while all the CPUs are being used by Tune, what can I do?

This is what I ended up doing:

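import os
import pickle


def get_policy_weights_from_checkpoint(trainer_class, checkpoint,
                                       policy_id="policy_1"):
    # Tune stores the trial config as params.pkl in the trial directory;
    # the checkpoint file may sit one level deeper (checkpoint_<n>/).
    checkpoint_dir = os.path.dirname(checkpoint)
    config_path = os.path.join(checkpoint_dir, 'params.pkl')
    if not os.path.exists(config_path):
        config_path = os.path.join(os.path.dirname(checkpoint_dir), 'params.pkl')
    with open(config_path, 'rb') as f:
        config = pickle.load(f)

    # Keep the throwaway trainer as small as possible.
    config['num_workers'] = 1
    config['evaluation_num_workers'] = 0
    eval_trainer = trainer_class(env="yaniv", config=config)
    eval_trainer.load_checkpoint(checkpoint)
    weights = eval_trainer.get_policy(policy_id).get_weights()
    # Shut the trainer down again to free its resources.
    eval_trainer.stop()

    return weights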
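Where params.pkl lives relative to the checkpoint file depends on Tune's directory layout. In the layout I'm aware of (<trial_dir>/checkpoint_<n>/checkpoint-<n>, with params.pkl in <trial_dir>), it is one level above the checkpoint's own directory. A small sketch that searches upward a bounded number of levels; find_params_pkl and max_levels are hypothetical names, and the layout is an assumption.

```python
import os
from typing import Optional


def find_params_pkl(checkpoint_path: str, max_levels: int = 2) -> Optional[str]:
    """Hypothetical helper: walk up from the checkpoint file's directory and
    return the first params.pkl found, or None."""
    directory = os.path.dirname(os.path.abspath(checkpoint_path))
    for _ in range(max_levels):
        candidate = os.path.join(directory, "params.pkl")
        if os.path.isfile(candidate):
            return candidate
        parent = os.path.dirname(directory)
        if parent == directory:  # reached the filesystem root
            break
        directory = parent
    return None
```

Bounding the search keeps it from picking up an unrelated params.pkl from some parent experiment directory.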

Is there a way to start a trainer without any resources? I know A3C won’t let you.


Hey @Rory, I completely agree: one should be able to extract weights (of any policy and any model) from a checkpoint without having to start the entire trainer. What you did with num_workers=1 is a cool hack that at least somewhat alleviates this problem, but yeah, extracting weights without a trainer is currently not possible. I’ll add this valid point to our internal planning doc (we would like to go over the train/save/restore API and make it more fine-grained).

No worries, glad to hear it’s on the list of things to do! I’ve ended up just saving the latest checkpoint and then saving the models’ weights at intervals using pickle.
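The approach of pickling the weights at intervals can be sketched as two small helpers (hypothetical names). The weights dict would come from something like trainer.get_weights(["policy_1"]) inside the training loop, and reading it back needs no trainer at all. Writing to a temporary file and then calling os.replace means an interrupted save cannot leave a truncated file at the target path.

```python
import os
import pickle
from typing import Any, Dict


def save_policy_weights(weights: Dict[str, Any], path: str) -> None:
    """Hypothetical helper: pickle {policy_id: weights} to `path`."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(weights, f)
    # Swap the finished file into place in one step.
    os.replace(tmp_path, path)


def load_policy_weights(path: str, policy_id: str) -> Any:
    """Hypothetical helper: read back one policy's weights without a trainer."""
    with open(path, "rb") as f:
        weights = pickle.load(f)
    return weights[policy_id]
```

For example, every N iterations call save_policy_weights(trainer.get_weights(["policy_1"]), "policy_1_weights.pkl"), and later pass load_policy_weights("policy_1_weights.pkl", "policy_1") to the target policy's set_weights().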

Cool! Not sure, but this may even be something for Tune: keep the Trainer(s) alive after tune.run() (more than one if there are multiple trials) and return them somehow as part of the result. The user would then be responsible for destroying them, though, which normally shouldn’t be an issue.