Hi all!
I am trying a self-play scheme in the Waterworld environment: two agents share a policy that is being trained ("shared_policy_1"), while the other three agents sample a policy from a menagerie (a set of previous snapshots of that trained policy), which is served as "shared_policy_2".
My problem is that the weights stored in the menagerie are overwritten on every iteration by the current weights, and I don't understand why this is happening (the same thing happens when I use a dictionary instead of a list for the men variable).
So instead of holding M previous policy weights, the menagerie ends up holding M copies of the current weights of training iteration i.
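I am not sure whether this is what actually happens inside RLlib, but here is a minimal standalone sketch (plain NumPy with a hypothetical toy "policy", no RLlib involved) of the kind of reference aliasing that would produce exactly this symptom, in case it helps narrow things down:

import copy
import numpy as np

# Hypothetical toy "policy": get_weights() returns a dict whose arrays
# still point at the live parameters (an assumption about what might be
# happening, not a statement about RLlib internals).
params = {"layer_1": np.zeros(3)}

def get_weights():
    return params

men_alias = []   # appending the returned dict directly, like my callback does
men_copy = []    # appending a deep copy instead

for step in range(3):
    params["layer_1"] += 1.0                        # "training" mutates the arrays in place
    men_alias.append(get_weights())                 # same underlying arrays each time
    men_copy.append(copy.deepcopy(get_weights()))   # independent snapshot

print([m["layer_1"][0] for m in men_alias])  # [3.0, 3.0, 3.0] -> all entries overwritten
print([m["layer_1"][0] for m in men_copy])   # [1.0, 2.0, 3.0] -> distinct entries

If get_weights() handed back references to the live parameters like this, every menagerie entry would silently track the current weights, which would match what I see.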
You can check the .txt files the script creates, since I cannot upload them here.
Thanks in advance.
Please find the code attached:
from ray import tune
from ray.rllib.agents.callbacks import DefaultCallbacks
import argparse
import gym
import os
import random
import ray
import numpy as np
from ray.tune.registry import register_env
from ray.rllib.env.pettingzoo_env import PettingZooEnv
from pettingzoo.sisl import waterworld_v3
M = 5 # Menagerie size
men = []
class MyCallbacks(DefaultCallbacks):
    def on_train_result(self, *, trainer, result: dict, **kwargs):
        print("trainer.train() result: {} -> {} episodes".format(
            trainer, result["episodes_this_iter"]))
        i = result['training_iteration']  # starts from 1
        # "shared_policy_1" is the only policy being trained
        print("training iteration:", i)
        global men
        if i <= M:
            # menagerie initialisation: fill it with the first M snapshots
            tmp = trainer.get_policy("shared_policy_1").get_weights()
            men.append(tmp)
            filename1 = 'file_init_' + str(i) + '.txt'
            textfile1 = open(filename1, "w")
            for element1 in men:
                textfile1.write("############# menagerie entries ##################" + "\n")
                textfile1.write(str(element1) + "\n")
            textfile1.close()
        else:
            # the oldest policy in the menagerie is erased
            men.pop(0)
            # add the current training policy at the last position of the menagerie
            w = trainer.get_policy("shared_policy_1").get_weights()
            men.append(w)
            # select one menagerie policy at random and load it into "shared_policy_2"
            sel = random.randint(0, M - 1)
            trainer.set_weights({"shared_policy_2": men[sel]})
            # push the updated local-worker state to all rollout workers
            weights = ray.put(trainer.workers.local_worker().save())
            trainer.workers.foreach_worker(
                lambda worker: worker.restore(ray.get(weights))
            )
            filename = 'file' + str(i) + '.txt'
            textfile = open(filename, "w")
            for element in men:
                textfile.write("############# menagerie entries ##################" + "\n")
                textfile.write(str(element) + "\n")
            textfile.close()
        # you can mutate the result dict to add new fields to return
        result["callback_ok"] = True
if __name__ == "__main__":
    ray.init()

    def env_creator(args):
        return PettingZooEnv(waterworld_v3.env(n_pursuers=5, n_evaders=5))

    env = env_creator({})
    register_env("waterworld", env_creator)

    obs_space = env.observation_space
    act_spc = env.action_space

    policies = {
        "shared_policy_1": (None, obs_space, act_spc, {}),
        "shared_policy_2": (None, obs_space, act_spc, {}),
    }
    def policy_mapping_fn(agent_id):
        # pursuer_0 and pursuer_1 use the trained policy;
        # the remaining pursuers act with the menagerie policy
        if agent_id in ("pursuer_0", "pursuer_1"):
            return "shared_policy_1"
        else:
            return "shared_policy_2"
    tune.run(
        "PPO",
        name="PPO self play n = 5, M=5 trial 1",
        stop={"episodes_total": 50000},
        checkpoint_freq=10,
        config={
            # Environment specific
            "env": "waterworld",
            # General
            "framework": "torch",
            "callbacks": MyCallbacks,
            "num_gpus": 0,
            "num_workers": 0,
            # Method specific
            "multiagent": {
                "policies": policies,
                "policies_to_train": ["shared_policy_1"],
                "policy_mapping_fn": policy_mapping_fn,
            },
        },
    )
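For completeness, here is a small throwaway check (a hypothetical debugging helper, not part of the script above) that could be dropped into on_train_result to confirm whether two menagerie entries really hold identical arrays:

import numpy as np

def menagerie_entries_identical(entry_a, entry_b):
    # Hypothetical debugging helper: True if two weight dicts hold
    # numerically identical arrays for every parameter name.
    if entry_a.keys() != entry_b.keys():
        return False
    return all(np.array_equal(entry_a[k], entry_b[k]) for k in entry_a)

# e.g. inside on_train_result, once the menagerie has at least two entries:
# print("oldest == newest:", menagerie_entries_identical(men[0], men[-1]))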