Example code from today’s office hours: it trains one independent PPO policy per pursuer on PettingZoo’s Waterworld with RLlib, using a custom callback to surface the per-pursuer reward means as Tune stop criteria and reporter columns.
import numpy as np
from pettingzoo.sisl import waterworld_v3
import ray
from ray.tune import CLIReporter
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import PettingZooEnv
from ray.tune.registry import register_env
from ray.rllib.algorithms.callbacks import DefaultCallbacks
class MyCallbacks(DefaultCallbacks):
    def on_train_result(self, *, algorithm, result: dict, **kwargs):
        # Copy each pursuer's mean policy reward into custom_metrics so the
        # values can be referenced as Tune stop criteria and reporter columns
        # (Tune flattens nested result dicts into "custom_metrics/..." paths).
        # .get(..., np.nan) guards against iterations with no finished episodes.
        result["custom_metrics"]["policy_reward_mean"] = {
            f"pursuer_{i}": result["policy_reward_mean"].get(f"pursuer_{i}", np.nan)
            for i in range(4)
        }

if __name__ == '__main__':
    # local_mode runs all of Ray in a single process; handy for debugging,
    # but it serializes the rollout workers, so drop it for real training.
    ray.init(local_mode=True)

    def env_creator(args):
        return PettingZooEnv(waterworld_v3.env())

    # One throwaway env instance to read off the (shared) spaces and agent ids.
    dummy_env = env_creator({})
    register_env("waterworld", env_creator)
    obs_space = dummy_env.observation_space
    act_space = dummy_env.action_space

    config = PPOConfig()
    config.multi_agent(
        # One independent policy per pursuer; possible_agents (unlike .agents)
        # is available without resetting the underlying PettingZoo env first.
        policies={
            pid: (None, obs_space, act_space, {})
            for pid in dummy_env.env.possible_agents
        },
        # Map each agent id straight onto the policy of the same name.
        policy_mapping_fn=(lambda agent_id, episode, **kwargs: agent_id),
    )
    config.rollouts(num_rollout_workers=4)
    config.environment(env="waterworld")
    config.callbacks(MyCallbacks)
    config = config.to_dict()

    tune.Tuner(
        "PPO",
        run_config=air.RunConfig(
            # Tune stops the trial as soon as ANY one of these conditions is
            # met, so episodes_total=1 makes this a quick smoke test.
            stop={
                "episodes_total": 1,
                "custom_metrics/policy_reward_mean/pursuer_0": 0,
                "custom_metrics/policy_reward_mean/pursuer_1": 0,
                "custom_metrics/policy_reward_mean/pursuer_2": 0,
                "custom_metrics/policy_reward_mean/pursuer_3": 0,
            },
            checkpoint_config=air.CheckpointConfig(
                checkpoint_frequency=1000,
            ),
            progress_reporter=CLIReporter(
                # Keys are result-dict paths, values are the column headers
                # shown in the console output.
                metric_columns={
                    "training_iteration": "training_iteration",
                    "time_total_s": "time_total_s",
                    "timesteps_total": "timesteps",
                    "episodes_this_iter": "episodes_trained",
                    "custom_metrics/policy_reward_mean/pursuer_0": "m_reward_p_0",
                    "custom_metrics/policy_reward_mean/pursuer_1": "m_reward_p_1",
                    "custom_metrics/policy_reward_mean/pursuer_2": "m_reward_p_2",
                    "custom_metrics/policy_reward_mean/pursuer_3": "m_reward_p_3",
                    "episode_reward_mean": "mean_reward_sum",
                },
                sort_by_metric=True,
            ),
        ),
        param_space=config,
    ).fit()
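
If you want to watch a trained checkpoint afterwards, here is a minimal rollout sketch. Assumptions on my part (not from the office-hours code): a Ray 2.x-era Algorithm.restore / compute_single_action API, the 4-tuple env.last() of the PettingZoo releases that shipped waterworld_v3, and a checkpoint_path you fill in from the run above.

from ray.rllib.algorithms.ppo import PPO

# checkpoint_path is hypothetical: fill in the path the Tuner run saved.
algo = PPO(config=config)
algo.restore(checkpoint_path)

# Roll out one episode with the raw AEC env, one agent at a time.
env = waterworld_v3.env()
env.reset()
for agent in env.agent_iter():
    obs, reward, done, info = env.last()
    # Each pursuer acts with its own policy (policy id == agent id);
    # the AEC API requires stepping with None once an agent is done.
    action = None if done else algo.compute_single_action(obs, policy_id=agent)
    env.step(action)
env.close()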