[RLlib] Multiagent with one pre-trained policy (vs another adversarial one)

Kathy J. asked this question on our Slack channel:


Hi, I’m running a multiagent adversarial setup in which I’m interested in training two policies. First, I have a trained RLlib policy that I want to reload. Second, I want to start the adversary agent from scratch. Is there any documentation you can point me at for doing this, or any example code? Thank you!

We don’t really support any pre-training settings/APIs right now. What you can do is use the multiagent config sub-dict to set up your two policies and the mapping from the environment’s agent IDs to those policies.
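
For concreteness, a minimal sketch of such a multiagent config could look like the following. This is only an illustration under assumptions: PPO as the trainer, the env name, the agent IDs, the second policy ID "adversary", and the per-agent spaces are all placeholders; "pol1" matches the identifier used in the snippet below.

from gym.spaces import Box, Discrete
from ray.rllib.agents.ppo import PPOTrainer  # PPO chosen only for illustration

# Placeholder per-agent spaces -- substitute your env's real spaces.
obs_space = Box(-1.0, 1.0, shape=(10,))
act_space = Discrete(4)

config = {
    "env": "my_multiagent_env",  # assumption: an env you registered with Tune
    "multiagent": {
        # policy_id -> (policy_cls, obs_space, act_space, extra_config);
        # policy_cls=None means "use the trainer's default policy class".
        "policies": {
            "pol1": (None, obs_space, act_space, {}),       # will be pre-loaded
            "adversary": (None, obs_space, act_space, {}),  # trained from scratch
        },
        # Map each env agent ID to one of the policy IDs above
        # ("agent_0" is a placeholder agent ID from your env).
        "policy_mapping_fn": lambda agent_id: "pol1" if agent_id == "agent_0" else "adversary",
    },
}

trainer = PPOTrainer(config=config)
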
Then you could pre-load the first policy (which you already trained) via something like:

trainer.get_policy("pol1").set_weights([pre-trained weights])

“pol1" would be the identifier for your first policy. This is the same identifier that you would use inside the “config->multiagent->policies” setting and “config->multiagent->policy_mapping_fn”.

Example: rllib/examples/multi_agent_custom_policy.py

What is the best way to load a specific policy’s weights from a checkpoint file to be used as another trainer’s policy? I’d like to use a pre-trained model to evaluate the currently training one against in a MARL setting.

Doing something like this doesn’t work for me:

        # get_trainer_class comes from ray.rllib.agents.registry
        loader = get_trainer_class(algo)(env="yaniv", config=config)
        loader.load_checkpoint(checkpoint_path)
        weights = loader.get_policy("policy_1").get_weights()
        self.trainer.set_weights({
            "eval_policy": weights
        })

I think this is because it creates a whole new trainer with all of its workers, whereas I just want the policy, and it gives the following error:

(pid=56641)   File "/home/jippo/Code/yaniv/yaniv-rl/yaniv_rl/utils/rllib/trainer.py", line 18, in setup
(pid=56641)     loader.load_checkpoint(checkpoint_path)
(pid=56641)   File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 755, in load_checkpoint
(pid=56641)     self.__setstate__(extra_data)
(pid=56641)   File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 191, in __setstate__
(pid=56641)     Trainer.__setstate__(self, state)
(pid=56641)   File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 1321, in __setstate__
(pid=56641)     self.workers.local_worker().restore(state["worker"])
(pid=56641)   File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1059, in restore
(pid=56641)     self.sync_filters(objs["filters"])
(pid=56641)   File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1026, in sync_filters
(pid=56641)     assert all(k in new_filters for k in self.filters)
(pid=56641) AssertionError

Hey @Rory, could you try doing something (admittedly hacky :confused:) like this:

# Get the weights dict/list of (learnt) policy1:
weights = my_trainer.get_policy([your policy's name or blank for "default_policy"]).get_weights()

# Transfer the weights to policy2:
my_new_trainer.get_policy([your policy's name or blank for "default_policy"]).set_weights(weights)
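
Tying this back to the original problem, a rough end-to-end sketch along those lines could look like the code below. It is only an assumption-laden illustration, not a tested recipe: PPOTrainer, the env name, pretrained_config, checkpoint_path, and the policy IDs "policy_1" / "eval_policy" are placeholders taken from the thread.

from ray.rllib.agents.ppo import PPOTrainer  # placeholder trainer class

# 1) Build a throw-away trainer and restore it from the checkpoint.
#    Its multiagent setup should match what was used when the checkpoint
#    was written, otherwise restoring can fail (e.g. on mismatched filters).
loader = PPOTrainer(env="yaniv", config=pretrained_config)
loader.restore(checkpoint_path)  # the public Tune/Trainable checkpoint API

# 2) Pull out just the weights of the policy you care about.
weights = loader.get_policy("policy_1").get_weights()
loader.stop()  # free the loader's workers; we only needed the weights

# 3) Push those weights into the policy of the trainer you are training.
trainer.get_policy("eval_policy").set_weights(weights)

# 4) Make sure all remote rollout workers see the new weights as well.
trainer.workers.foreach_worker(
    lambda w: w.get_policy("eval_policy").set_weights(weights)
)

The loader trainer only exists long enough to deserialize the checkpoint; stopping it right after extracting the weights means you don’t pay for two full sets of rollout workers while the adversarial training runs.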