What is the best way to load a specific policy's weights from a checkpoint file so they can be used as another trainer's policy? I'd like to evaluate the currently training model against a pre-trained one in a MARL setting.
Doing something like this doesn’t work for me:
loader = get_trainer_class(algo)(env="yaniv", config=config)
loader.load_checkpoint(checkpoint_path)
weights = loader.get_policy("policy_1").get_weights()
self.trainer.set_weights({
    "eval_policy": weights
})
I think this is because this approach constructs a whole new trainer, workers and all, whereas I only need the one policy. It fails with the following error:
(pid=56641) File "/home/jippo/Code/yaniv/yaniv-rl/yaniv_rl/utils/rllib/trainer.py", line 18, in setup
(pid=56641) loader.load_checkpoint(checkpoint_path)
(pid=56641) File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 755, in load_checkpoint
(pid=56641) self.__setstate__(extra_data)
(pid=56641) File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 191, in __setstate__
(pid=56641) Trainer.__setstate__(self, state)
(pid=56641) File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 1321, in __setstate__
(pid=56641) self.workers.local_worker().restore(state["worker"])
(pid=56641) File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1059, in restore
(pid=56641) self.sync_filters(objs["filters"])
(pid=56641) File "/home/jippo/.conda/envs/yaniv-torch/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1026, in sync_filters
(pid=56641) assert all(k in new_filters for k in self.filters)
(pid=56641) AssertionError