Pre-train one type of policy in MARL

I am trying to pre-train one type of agent first and then jointly train both types of agents. The code is as follows:


import ray
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.examples.rl_modules.classes.random_rlm import RandomRLModule

from Environment.multiagent_environment import SignalPlatoonEnv
from ray import tune
import os
import sys
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from RL_module import SPTorchMultiAgentModuleWithSharedEncoder, SPTorchRLModuleWithSharedGlobalEncoder, SP_PPOCatalog
if 'SUMO_HOME' in os.environ:
    tools = os.path.join(os.environ['SUMO_HOME'], 'tools')
    sys.path.append(tools)
else:
    sys.exit("please declare environment variable 'SUMO_HOME'")

from sumolib import checkBinary # noqa
import traci # noqa
from Environment.network_information import sa_get_relevant_agent



relevant_ss_set, relevant_sp_num = sa_get_relevant_agent()

def generate_module_specs(env, scenario, signal_agent_list):
    # Build one RLModule spec per policy (policy -> module).
    module_spec_dict = {}

    input_dim_dict = env.input_dim_dict
    contacter_dim_set = {}
    hidden_dim = 8
    contacter_out = 32
    self_hidden_dim = 32

    # The concatenation dim of each policy grows with the number of relevant agents it observes.
    for policy in env.observation_space_policy:
        if scenario == "signal_platoon":
            if policy.startswith("signal"):
                relevant_agent_num = len(relevant_ss_set[policy]) + relevant_sp_num
            else:
                relevant_agent_num = 1
        else:
            if policy.startswith("signal"):
                relevant_agent_num = len(relevant_ss_set[policy])
            else:
                relevant_agent_num = 0
        contacter_dim = relevant_agent_num * hidden_dim + self_hidden_dim
        contacter_dim_set[policy] = contacter_dim

    for policy in env.observation_space_policy:
        module_spec_dict[policy] = SingleAgentRLModuleSpec(
            module_class=SPTorchRLModuleWithSharedGlobalEncoder,
            observation_space=env.observation_space_policy[policy],
            action_space=env.action_space_policy[policy],
            model_config_dict={
                "input_dim_dict": input_dim_dict,
                "hidden_dim": hidden_dim,
                "contacter_dim": contacter_dim_set[policy],
                "contacter_out": contacter_out,
                "scenario": scenario,
                "signal_agent_list": signal_agent_list,
                "self_hidden_dim": self_hidden_dim,
            },
            catalog_class=SP_PPOCatalog,
        )

    return module_spec_dict

ray.init()

netfile='Environment/network_2_3/network23.net.xml'
routefile='Environment/network_2_3/network23.rou.xml'
cfg_file='Environment/network_2_3/network23.sumocfg'

scenario='signal_platoon'

env=SignalPlatoonEnv(scenario=scenario,
                     net_file=netfile,
                     route_file=routefile,
                     cfg_file=cfg_file,
                     use_gui=False,
                     direct_start=True
                    )
policy_set, policy_map, signal_policies = env.get_policy_dict()
policy_map_fn = lambda agent_id, *args, **kwargs: policy_map[agent_id]
signal_agent_list = env.signal_agent_ids
module_specs = generate_module_specs(env, scenario, signal_agent_list)

config = (
    PPOConfig()
    .experimental(_disable_preprocessor_api=True)
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .env_runners(num_env_runners=2)
    .training(
        model={
            "_disable_preprocessor_api": True,
            "use_new_env_runners": True,
        },
        train_batch_size=500,
        use_kl_loss=True,
    )
)
config.sample_timeout_s = 60
config.rollout_fragment_length = 50

tune.register_env("sp_env", lambda _: env)
# Pre-training phase: only the signal policies should be updated; the platoon
# policies still act in the environment but are not supposed to be trained.
config_pre_train_signal = config.environment("sp_env").multi_agent(
    policies=policy_set,
    policy_mapping_fn=policy_map_fn,
    policies_to_train=list(signal_policies),
).rl_module(
    model_config_dict={"vf_share_layers": True},
    rl_module_spec=MultiAgentRLModuleSpec(
        marl_module_class=SPTorchMultiAgentModuleWithSharedEncoder,
        module_specs=module_specs,
    ),
)
pre_train = config_pre_train_signal.build()


for i in range(2):
    #iteration
    pre_train.train()

print('end_pre_train')

check_point_dir = pre_train.save().checkpoint.path

# Joint training phase: restore from the pre-training checkpoint and now train all policies.
joint_train = Algorithm.from_checkpoint(
    check_point_dir,
    policy_ids=policy_set,
    policy_mapping_fn=policy_map_fn,
    policies_to_train=list(policy_set),
)
for i in range(2):
    result_joint_train = joint_train.train()

The error occurs in the first training phase (the pre-training loop), as follows:
  File "D:\signal_platoon\train.py", line 107, in <module>
    pre_train.train()
  File "D:\signal_platoon\venv\Lib\site-packages\ray\tune\trainable\trainable.py", line 331, in train
    raise skipped from exception_cause(skipped)
  File "D:\signal_platoon\venv\Lib\site-packages\ray\tune\trainable\trainable.py", line 328, in train
    result = self.step()
             ^^^^^^^^^^^
  File "D:\signal_platoon\venv\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 877, in step
    train_results, train_iter_ctx = self._run_one_training_iteration()
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\signal_platoon\venv\Lib\site-packages\ray\rllib\algorithms\algorithm.py", line 3158, in _run_one_training_iteration
    results = self.training_step()
              ^^^^^^^^^^^^^^^^^^^^
  File "D:\signal_platoon\venv\Lib\site-packages\ray\rllib\algorithms\ppo\ppo.py", line 424, in training_step
    return self._training_step_new_api_stack()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "D:\signal_platoon\venv\Lib\site-packages\ray\rllib\algorithms\ppo\ppo.py", line 534, in _training_step_new_api_stack
    self.metrics.peek(LEARNER_RESULTS, mid, LEARNER_RESULTS_KL_KEY)
  File "D:\signal_platoon\venv\Lib\site-packages\ray\rllib\utils\metrics\metrics_logger.py", line 665, in peek
    ret = tree.map_structure(lambda s: s.peek(), self._get_key(key))
                                                 ^^^^^^^^^^^^^^^^^^
  File "D:\signal_platoon\venv\Lib\site-packages\ray\rllib\utils\metrics\metrics_logger.py", line 899, in _get_key
    _dict = _dict[key]
            ~~~~~^^^^^
KeyError: 'mean_kl_loss'

I have set policy_set and policy_mapping_fn so that they contain both the signal and the platoon policies, and I have set policies_to_train to include only the signal policies. But it seems the platoon policies are still being treated as modules to train.
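
To check which modules actually get optimized during the pre-training phase, I added a rough diagnostic after each train() call. This is only a sketch, under the assumption that the per-module learner stats are reported in the result dict under the same LEARNER_RESULTS key that appears in the traceback (the exact layout of the result dict may differ between Ray versions):

from ray.rllib.utils.metrics import LEARNER_RESULTS

for i in range(2):
    result = pre_train.train()
    # Assumed layout: per-module learner stats keyed by module ID under LEARNER_RESULTS.
    learner_results = result.get(LEARNER_RESULTS, {})
    for module_id, stats in learner_results.items():
        # If the platoon modules show up here with loss stats, they are being
        # optimized even though they are not listed in policies_to_train.
        keys = sorted(stats.keys()) if isinstance(stats, dict) else stats
        print(module_id, keys)

If the platoon modules do show loss statistics there, that would confirm they are still being trained. As a fallback I could try use_kl_loss=False, since the failure happens in the KL-metric peek, but that changes the PPO loss, so I would rather fix the policies_to_train setup properly.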

Is my code written in the correct way? How can I solve this error? Thank you so much!