Hello! I am trying to migrate the legacy RLLib code based on build_trainer() below. It is from this repo.
HandCodedTrainer = tt.build_trainer("HandCoded", wsr.SimplePolicy)

ext_conf = {
    "multiagent": {
        "policies": filter_keys(policies, ['baseline_producer', 'baseline_consumer']),
        "policy_mapping_fn": lambda agent_id: 'baseline_producer' if Utils.is_producer_agent(agent_id) else 'baseline_consumer',
        "policies_to_train": ['baseline_producer', 'baseline_consumer']
    }
}

handcoded_trainer = HandCodedTrainer(
    env=wsr.WorldOfSupplyEnv,
    config=dict(base_trainer_config, **ext_conf))

for i in range(n_iterations):
    print("== Iteration", i, "==")
    print_training_results(handcoded_trainer.train())
Below is the revised code for RLLib 2.9
class Baseline(PPO):
    @classmethod
    def get_default_policy_class(cls, config):
        return SimplePolicy
and
algo = (
    PPOConfig(Baseline)
    .rollouts(num_rollout_workers=0, rollout_fragment_length=50, batch_mode='complete_episodes')
    .framework('tf2')
    .resources(num_gpus=0)
    .training(
        train_batch_size=2000,
        gamma=0.99
    )
    .multi_agent(
        policies=filter_keys(policies, ['baseline_producer', 'baseline_consumer']),
        policy_mapping_fn=lambda agent_id, episode, worker:
            'baseline_producer' if Utils.is_producer_agent(agent_id) else 'baseline_consumer',
        policies_to_train=['baseline_producer', 'baseline_consumer']
    )
    .environment(
        env=WorldOfSupplyEnv,
        env_config=env_config,
        disable_env_checking=True
    )
    .build()
)

for i in range(n_iterations):
    print('== Iteration', i, '==')
    algo.workers.foreach_worker(
        lambda w: w.foreach_env(
            lambda env: env.set_iteration(i, n_iterations)
        )
    )
    print_training_results(algo.train())
Unfortunately, it throws the following error.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[50], line 11
8 wsrt.print_model_summaries()
10 # Policy training
---> 11 algo = wsrt.train_baseline(n_iterations=1)
12 #algo = wsrt.train_ppo(n_iterations=600)
File ~/GridDynamics/tensor-house/supply-chain/world-of-supply/world_of_supply_rllib_training.py:170, in train_baseline(n_iterations)
164 print('== Iteration', i, '==')
165 algo.workers.foreach_worker(
166 lambda w: w.foreach_env(
167 lambda env: env.set_iteration(i, n_iterations)
168 )
169 )
--> 170 print_training_results(algo.train())
172 return algo
File /opt/conda/lib/python3.11/site-packages/ray/tune/trainable/trainable.py:342, in Trainable.train(self)
340 except Exception as e:
341 skipped = skip_exceptions(e)
--> 342 raise skipped from exception_cause(skipped)
344 assert isinstance(result, dict), "step() needs to return a dict."
346 # We do not modify internal state nor update this result if duplicate.
File /opt/conda/lib/python3.11/site-packages/ray/tune/trainable/trainable.py:339, in Trainable.train(self)
337 start = time.time()
338 try:
--> 339 result = self.step()
340 except Exception as e:
341 skipped = skip_exceptions(e)
File /opt/conda/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py:852, in Algorithm.step(self)
844 (
845 results,
846 train_iter_ctx,
847 ) = self._run_one_training_iteration_and_evaluation_in_parallel()
848 # - No evaluation necessary, just run the next training iteration.
849 # - We have to evaluate in this training iteration, but no parallelism ->
850 # evaluate after the training iteration is entirely done.
851 else:
--> 852 results, train_iter_ctx = self._run_one_training_iteration()
854 # Sequential: Train (already done above), then evaluate.
855 if evaluate_this_iter and not self.config.evaluation_parallel_to_training:
File /opt/conda/lib/python3.11/site-packages/ray/rllib/algorithms/algorithm.py:3042, in Algorithm._run_one_training_iteration(self)
3040 with self._timers[TRAINING_ITERATION_TIMER]:
3041 if self.config._disable_execution_plan_api:
-> 3042 results = self.training_step()
3043 else:
3044 results = next(self.train_exec_impl)
File /opt/conda/lib/python3.11/site-packages/ray/rllib/algorithms/ppo/ppo.py:540, in PPO.training_step(self)
536 # For each policy: Update KL scale and warn about possible issues
537 for policy_id, policy_info in train_results.items():
538 # Update KL loss with dynamic scaling
539 # for each (possibly multiagent) policy we are training
--> 540 kl_divergence = policy_info[LEARNER_STATS_KEY].get("kl")
541 self.get_policy(policy_id).update_kl(kl_divergence)
543 # Warn about excessively high value function loss
KeyError: 'learner_stats'
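If I am reading the traceback correctly, PPO.training_step() looks up policy_info[LEARNER_STATS_KEY] for every policy it trains, while my SimplePolicy.learn_on_batch() (shown further below) just returns {}, so the 'learner_stats' key never exists. Purely as a guess, this is roughly what I imagine the return value would have to look like (LEARNER_STATS_KEY comes from ray.rllib.utils.metrics.learner_info; the "kl" value is a placeholder, and I am not sure this is the intended way to do it):

from ray.rllib.policy.policy import Policy
from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY

class SimplePolicy(Policy):
    # ... __init__ and compute_actions unchanged ...

    def learn_on_batch(self, samples):
        # Guess: report dummy stats under LEARNER_STATS_KEY ("learner_stats")
        # so that PPO.training_step() finds the key it looks up. The policy
        # does no learning, so the values are placeholders.
        return {LEARNER_STATS_KEY: {"kl": 0.0}}

I also wonder whether the baseline policies should simply not be listed in policies_to_train at all, since they never learn, but I could not confirm that either.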
I tried to follow the How to Customize Policies guide, but I am still failing to understand how to create the policy in a way that RLLib can use with TF2.
I am fairly new to RL and RLLib…
Any hint would be great.
Below is the definition of the custom policy
class SimplePolicy(Policy):
    def __init__(self, observation_space, action_space, config):
        Policy.__init__(self, observation_space, action_space, config)
        self.action_space_shape = action_space.shape
        self.n_products = config['number_of_products']
        self.n_sources = config['number_of_sources']

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        if info_batch is None:
            action_dict = [self._action(f_state, None) for f_state in obs_batch], [], {}
        else:
            action_dict = [self._action(f_state, f_state_info)
                           for f_state, f_state_info in zip(obs_batch, info_batch)], [], {}
        return action_dict

    # no learning
    def learn_on_batch(self, samples):
        return {}

    def get_weights(self):
        return {}

    def set_weights(self, weights):
        pass


def get_config_from_env(env):
    return {'facility_types': env.facility_types,
            'number_of_products': env.n_products(),
            'number_of_sources': env.max_sources_per_facility}
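For context, the policies dict that the .multi_agent() call above filters with filter_keys() is built roughly like the sketch below. This is simplified, not the literal code from my training script: the helper name create_baseline_policies is just for illustration, PolicySpec is from ray.rllib.policy.policy, and I am assuming the env exposes single-agent observation_space / action_space attributes.

from ray.rllib.policy.policy import PolicySpec

def create_baseline_policies(env):
    # Simplified sketch: both baseline roles reuse SimplePolicy, configured
    # with the per-env values produced by get_config_from_env(env).
    baseline_config = get_config_from_env(env)
    return {
        'baseline_producer': PolicySpec(policy_class=SimplePolicy,
                                        observation_space=env.observation_space,
                                        action_space=env.action_space,
                                        config=baseline_config),
        'baseline_consumer': PolicySpec(policy_class=SimplePolicy,
                                        observation_space=env.observation_space,
                                        action_space=env.action_space,
                                        config=baseline_config),
    }

I am also not sure whether specifying SimplePolicy both here and via get_default_policy_class() is redundant, or whether one of the two is the wrong place for it.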
JP