import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print
from ray.rllib.examples.env.random_env import RandomMultiAgentEnv
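
# Two agents, each mapped to its own PPO policy via the multiagent config below.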
num_agents = 2
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["env_config"] = {
"num_agents" : num_agents,
}
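# Instantiate the env locally to read off its observation/action spaces.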
env = RandomMultiAgentEnv(config["env_config"])
config["multiagent"] = {
"policies" : { # (policy_cls, obs_space, act_space, config)
"{}".format(x): (None, env.observation_space, env.action_space, {}) for x in range(num_agents)
},
"policy_mapping_fn": lambda x: "{}".format(x),
}
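# Start Ray and build the PPO trainer on the multi-agent env class.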
ray.init()
trainer = ppo.PPOTrainer(config=config, env=RandomMultiAgentEnv)
# Short training run.
for i in range(3):
    result = trainer.train()
    print(pretty_print(result))
# Evaluate the trained policies over one episode.
done = False
count = 0
obs_dict = env.reset()
while not done:
    action_dict = {}
    for policy in config["multiagent"]["policies"].keys():
        # RandomMultiAgentEnv's agent keys are integers, not strings.
        action_dict[int(policy)] = trainer.compute_action(
            obs_dict[int(policy)], policy_id=policy
        )
    obs_dict, reward_dict, done_dict, info = env.step(action_dict)
    count += 1
    done = done_dict["__all__"]
    if done:
        print(count)
        print(reward_dict)
        print(done_dict)
ray.shutdown()
Cheers,