Trouble migrating legacy build_trainer() code to RLlib 2.9

Hello! I am trying to migrate the legacy RLlib code below, which is based on build_trainer(). It comes from this repo.

    HandCodedTrainer = tt.build_trainer("HandCoded", wsr.SimplePolicy)
    ext_conf = {
            "multiagent": {
                "policies": filter_keys(policies, ['baseline_producer', 'baseline_consumer']),
                "policy_mapping_fn": lambda agent_id: 'baseline_producer' if Utils.is_producer_agent(agent_id) else 'baseline_consumer',
                "policies_to_train": ['baseline_producer', 'baseline_consumer']
            }
        }
    handcoded_trainer = HandCodedTrainer(
        env = wsr.WorldOfSupplyEnv,
        config = dict(base_trainer_config, **ext_conf))

    for i in range(n_iterations):
        print("== Iteration", i, "==")
        print_training_results(handcoded_trainer.train())

Below is the revised code for RLlib 2.9.

    from ray.rllib.algorithms.ppo import PPO, PPOConfig

    class Baseline(PPO):
        @classmethod
        def get_default_policy_class(cls, config):
            return SimplePolicy

and

    algo = (
        PPOConfig(Baseline)
            .rollouts(num_rollout_workers=0, rollout_fragment_length=50, batch_mode='complete_episodes')
            .framework('tf2')
            .resources(num_gpus=0)
            .training(
                train_batch_size=2000,
                gamma=0.99
            )
            .multi_agent(
                policies=filter_keys(policies, ['baseline_producer', 'baseline_consumer']),
                policy_mapping_fn=lambda agent_id, episode, worker:
                                    'baseline_producer' if Utils.is_producer_agent(agent_id) else 'baseline_consumer',
                policies_to_train=['baseline_producer', 'baseline_consumer']
            )
            .environment(
                env=WorldOfSupplyEnv,
                env_config=env_config,
                disable_env_checking=True
            )
            .build()
    )
    
    for i in range(n_iterations):
        print('== Iteration', i, '==')
        algo.workers.foreach_worker(
            lambda w: w.foreach_env(
                lambda env: env.set_iteration(i, n_iterations)
            )
        )
        print_training_results(algo.train())

Unfortunately, it throws the following error.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[50], line 11
      8 wsrt.print_model_summaries()
     10 # Policy training
---> 11 algo = wsrt.train_baseline(n_iterations=1)
     12 #algo = wsrt.train_ppo(n_iterations=600)

File ~/GridDynamics/tensor-house/supply-chain/world-of-supply/world_of_supply_rllib_training.py:170, in train_baseline(n_iterations)
    164     print('== Iteration', i, '==')
    165     algo.workers.foreach_worker(
    166         lambda w: w.foreach_env(
    167             lambda env: env.set_iteration(i, n_iterations)
    168         )
    169     )
--> 170     print_training_results(algo.train())
    172 return algo

File /opt/conda/lib/python3.11/site-packages/ray/tune/trainable/trainable.py:342, in Trainable.train(self)
    340 except Exception as e:
    341     skipped = skip_exceptions(e)
--> 342     raise skipped from exception_cause(skipped)
    344 assert isinstance(result, dict), "step() needs to return a dict."
    346 # We do not modify internal state nor update this result if duplicate.

File /opt/conda/lib/python3.11/site-packages/ray/tune/trainable/trainable.py:339, in Trainable.train(self)
    337 start = time.time()
    338 try:
--> 339     result = self.step()
    340 except Exception as e:
    341     skipped = skip_exceptions(e)

File /opt/conda/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py:852, in Algorithm.step(self)
    844     (
    845         results,
    846         train_iter_ctx,
    847     ) = self._run_one_training_iteration_and_evaluation_in_parallel()
    848 # - No evaluation necessary, just run the next training iteration.
    849 # - We have to evaluate in this training iteration, but no parallelism ->
    850 #   evaluate after the training iteration is entirely done.
    851 else:
--> 852     results, train_iter_ctx = self._run_one_training_iteration()
    854 # Sequential: Train (already done above), then evaluate.
    855 if evaluate_this_iter and not self.config.evaluation_parallel_to_training:

File /opt/conda/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py:3042, in Algorithm._run_one_training_iteration(self)
   3040 with self._timers[TRAINING_ITERATION_TIMER]:
   3041     if self.config._disable_execution_plan_api:
-> 3042         results = self.training_step()
   3043     else:
   3044         results = next(self.train_exec_impl)

File /opt/conda/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py:540, in PPO.training_step(self)
    536 # For each policy: Update KL scale and warn about possible issues
    537 for policy_id, policy_info in train_results.items():
    538     # Update KL loss with dynamic scaling
    539     # for each (possibly multiagent) policy we are training
--> 540     kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl")
    541     self.get_policy(policy_id).update_kl(kl_divergence)
    543     # Warn about excessively high value function loss

KeyError: 'learner_stats'

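From the traceback it looks like PPO.training_step() expects each policy's learn_on_batch() result to contain a stats dict under LEARNER_STATS_KEY ('learner_stats'). Below is a minimal sketch of what I suspect it might want (just my guess from reading the traceback, not tested; and even with this, the next line of training_step() calls update_kl(), which my SimplePolicy does not define):

    # Guess only: report (empty) learner stats under LEARNER_STATS_KEY so that
    # PPO.training_step() can read policy_info[LEARNER_STATS_KEY].get("kl").
    from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

    def learn_on_batch(self, samples):
        # Still no actual learning; just return an empty stats dict.
        return {LEARNER_STATS_KEY: {}}
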
I went through the "How to Customize Policies" guide, but I still do not understand how to define the policy so that RLlib can use it with TF2.

I am fairly new to RL and RLlib, so any hint would be greatly appreciated.

Below is the definition of the custom policy.

    from ray.rllib.policy.policy import Policy

    class SimplePolicy(Policy):
        def __init__(self, observation_space, action_space, config):
            Policy.__init__(self, observation_space, action_space, config)
            self.action_space_shape = action_space.shape
            self.n_products = config['number_of_products']
            self.n_sources = config['number_of_sources']

        def compute_actions(self,
                            obs_batch,
                            state_batches=None,
                            prev_action_batch=None,
                            prev_reward_batch=None,
                            info_batch=None,
                            episodes=None,
                            **kwargs):
            # Return the (actions, state_outs, extra_fetches) tuple expected by RLlib.
            if info_batch is None:
                actions = [self._action(f_state, None) for f_state in obs_batch]
            else:
                actions = [self._action(f_state, f_state_info)
                           for f_state, f_state_info in zip(obs_batch, info_batch)]
            return actions, [], {}

        # no learning
        def learn_on_batch(self, samples):
            return {}

        def get_weights(self):
            return {}

        def set_weights(self, weights):
            pass

        @staticmethod
        def get_config_from_env(env):
            return {'facility_types': env.facility_types,
                    'number_of_products': env.n_products(),
                    'number_of_sources': env.max_sources_per_facility}

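For completeness, I also wondered whether the policy class should instead be attached per policy ID via PolicySpec, rather than through get_default_policy_class. A rough sketch of what I mean (the entries below are placeholders, not my actual policies dict, and I have not tried this yet):

    # Rough idea only: declare the custom policy class per policy ID via PolicySpec.
    from ray.rllib.policy.policy import PolicySpec

    policies = {
        'baseline_producer': PolicySpec(policy_class=SimplePolicy),
        'baseline_consumer': PolicySpec(policy_class=SimplePolicy),
    }
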
JP