The model is adapted from ComplexInputNetwork and modified for centralized-critic usage.
When I try to restore the model, the following error occurs. I'm confused about why there are unexpected keys in the state_dict.
# Multi-agent overrides that are merged on top of the base ``cfg`` (``cfg``
# itself is built outside this snippet). Exactly one policy,
# "shared_policy", is declared, and the mapping function routes every agent
# id to it, so all agents act through the same policy.
# The policy spec tuple is (policy_class, observation_space, action_space,
# per-policy config). ``None`` as the class presumably lets the trainer pick
# its default policy class — confirm against the RLlib version in use.
#
# NOTE(review): the restore at the bottom fails (see the traceback). RLlib's
# Policy.set_state hands the entire saved state to set_weights(), which
# calls load_state_dict(). That state carries wrapper keys ("weights",
# "global_timestep", "_exploration_state") instead of raw model parameters.
# This presumably means the checkpoint was written by a different
# Ray/RLlib version than the one restoring it — confirm the versions match.
config = dict(
    num_workers=1,
    num_cpus_per_worker=1,
    num_gpus_per_worker=0,
    multiagent=dict(
        policies=dict(
            shared_policy=(
                None,
                CoverageEnv.single_agent_observation_space,
                CoverageEnv.single_agent_action_space,
                dict(framework="torch"),
            ),
        ),
        # Every agent, whatever its id, uses the single shared policy.
        policy_mapping_fn=lambda agent_id: "shared_policy",
    ),
)

# Layer the overrides onto the base config, build the trainer for the
# configured env, then load the saved checkpoint into it.
cfg = update_dict(cfg, config)
trainer = CCTrainer(env=cfg["env"], config=cfg)
trainer.restore(str(checkpoint_file))
Traceback (most recent call last):
File "/Users/liuyungkai/PycharmProjects/cpp_grid/evaluate.py", line 123, in <module>
run_trial(checkpoint_path=checkpoint_path, render=False)
File "/Users/liuyungkai/PycharmProjects/cpp_grid/evaluate.py", line 70, in run_trial
trainer.restore(str(checkpoint_file))
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/tune/trainable.py", line 372, in restore
self.load_checkpoint(checkpoint_path)
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 763, in load_checkpoint
self.__setstate__(extra_data)
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/rllib/agents/trainer_template.py", line 223, in __setstate__
Trainer.__setstate__(self, state)
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 1390, in __setstate__
self.workers.local_worker().restore(state["worker"])
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/rllib/evaluation/rollout_worker.py", line 1099, in restore
self.policy_map[pid].set_state(state)
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 616, in set_state
super().set_state(state)
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/rllib/policy/policy.py", line 482, in set_state
self.set_weights(state)
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/ray/rllib/policy/torch_policy.py", line 574, in set_weights
self.model.load_state_dict(weights)
File "/Users/liuyungkai/opt/anaconda3/envs/playground/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1052, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ComplexInputNetworkandCentrailzedCritic:
Missing key(s) in state_dict: "cnn_all_channel._convs.0._model.1.weight", "cnn_all_channel._convs.0._model.1.bias", "cnn_all_channel._convs.1._model.1.weight", "cnn_all_channel._convs.1._model.1.bias", "cnn_all_channel._convs.2._model.1.weight", "cnn_all_channel._convs.2._model.1.bias", "cnn_all_channel._convs.3._model.0.weight", "cnn_all_channel._convs.3._model.0.bias", "cnn_all_channel._value_branch_separate.0._model.1.weight", "cnn_all_channel._value_branch_separate.0._model.1.bias", "cnn_all_channel._value_branch_separate.1._model.1.weight", "cnn_all_channel._value_branch_separate.1._model.1.bias", "cnn_all_channel._value_branch_separate.2._model.1.weight", "cnn_all_channel._value_branch_separate.2._model.1.bias", "cnn_all_channel._value_branch_separate.3._model.0.weight", "cnn_all_channel._value_branch_separate.3._model.0.bias", "cnn_all_channel._value_branch_separate.4._model.0.weight", "cnn_all_channel._value_branch_separate.4._model.0.bias", "cnn_global_critic._convs.0._model.1.weight", "cnn_global_critic._convs.0._model.1.bias", "cnn_global_critic._convs.1._model.1.weight", "cnn_global_critic._convs.1._model.1.bias", "cnn_global_critic._convs.2._model.1.weight", "cnn_global_critic._convs.2._model.1.bias", "cnn_global_critic._convs.3._model.0.weight", "cnn_global_critic._convs.3._model.0.bias", "cnn_global_critic._value_branch_separate.0._model.1.weight", "cnn_global_critic._value_branch_separate.0._model.1.bias", "cnn_global_critic._value_branch_separate.1._model.1.weight", "cnn_global_critic._value_branch_separate.1._model.1.bias", "cnn_global_critic._value_branch_separate.2._model.1.weight", "cnn_global_critic._value_branch_separate.2._model.1.bias", "cnn_global_critic._value_branch_separate.3._model.0.weight", "cnn_global_critic._value_branch_separate.3._model.0.bias", "cnn_global_critic._value_branch_separate.4._model.0.weight", "cnn_global_critic._value_branch_separate.4._model.0.bias", "post_fc_stack._hidden_layers.0._model.0.weight", 
"post_fc_stack._hidden_layers.0._model.0.bias", "post_fc_stack._hidden_layers.1._model.0.weight", "post_fc_stack._hidden_layers.1._model.0.bias", "post_fc_stack._value_branch_separate.0._model.0.weight", "post_fc_stack._value_branch_separate.0._model.0.bias", "post_fc_stack._value_branch_separate.1._model.0.weight", "post_fc_stack._value_branch_separate.1._model.0.bias", "post_fc_stack._value_branch._model.0.weight", "post_fc_stack._value_branch._model.0.bias", "logits_layer._model.0.weight", "logits_layer._model.0.bias", "value_layer._model.0.weight", "value_layer._model.0.bias", "vf_post_fc_stack._hidden_layers.0._model.0.weight", "vf_post_fc_stack._hidden_layers.0._model.0.bias", "vf_post_fc_stack._hidden_layers.1._model.0.weight", "vf_post_fc_stack._hidden_layers.1._model.0.bias", "vf_post_fc_stack._value_branch_separate.0._model.0.weight", "vf_post_fc_stack._value_branch_separate.0._model.0.bias", "vf_post_fc_stack._value_branch_separate.1._model.0.weight", "vf_post_fc_stack._value_branch_separate.1._model.0.bias", "vf_post_fc_stack._value_branch._model.0.weight", "vf_post_fc_stack._value_branch._model.0.bias", "central_vf_layer._model.0.weight", "central_vf_layer._model.0.bias".
Unexpected key(s) in state_dict: "weights", "global_timestep", "_exploration_state".