- High: It blocks me from completing my task.
Hi. I want to train agents in multiple environments simultaneously. For example, I want to train 8 PPO agents in 8 different environments (same architecture, same hyperparameters).
So I tried to implement a distributed training function by combining Ray remote tasks with the RLlib trainer. But when I ran this code, the agents were not trained in parallel. Is there some lock inside the RLlib trainer? How can I train multiple trainers simultaneously?
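To make the intent clearer, here is a simplified sketch of the pattern I expect to work, with one Ray task per trainer (train_one is just a placeholder and omits the RLlib config):

import ray

ray.init()  # start a local Ray instance (or connect to an existing cluster)

@ray.remote
def train_one(env_name):
    # Placeholder: build an RLlib trainer for this env and call train() in a loop.
    return env_name

# One task per environment; I expect these tasks to run in parallel.
refs = [train_one.remote(name) for name in ["CartPole-v0", "MountainCar-v0"]]
print(ray.get(refs))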
And here is my full code:
import gym
import ray
from ray.tune.registry import register_env
from ray.rllib.agents.ppo import PPOTrainer, PPOConfig
from ray.tune.logger import pretty_print
env_names = ['CartPole-v0', 'MountainCar-v0', 'Taxi-v3', 'SpaceInvaders-v0',
             'LunarLander-v2', 'Humanoid-v2', 'FrozenLake-v0', 'HandManipulateBlock-v0']
def env_creator(env_config):
    # Build and seed a single gym environment from the per-trainer env_config.
    env_name = env_config["env"]
    seed = env_config["seed"]
    env = gym.make(env_name)
    env.seed(seed)
    return env
for env_name in env_names:
    register_env(env_name, env_creator)
@ray.remote(num_cpus=4, num_gpus=1)
def distributed_trainer(env_name):
    # Each remote task builds and trains its own PPO trainer on one environment.
    config = (
        PPOConfig()
        .training(
            gamma=0.99,
            lr=0.0005,
            train_batch_size=1000,
            model={
                "fcnet_hiddens": [128, 128],
                "fcnet_activation": "tanh",
            },
            use_gae=True,
            lambda_=0.95,
            vf_loss_coeff=0.2,
            entropy_coeff=0.001,
            num_sgd_iter=5,
            sgd_minibatch_size=32,
            shuffle_sequences=True,
        )
        .resources(
            num_gpus=1,
            num_cpus_per_worker=2,
        )
        .framework(framework='torch')
        .environment(
            env=env_name,
            render_env=True,
            env_config={"env": env_name, "seed": 1},
        )
        .rollouts(
            num_rollout_workers=2,
            num_envs_per_worker=2,
            create_env_on_local_worker=False,
            rollout_fragment_length=250,
            horizon=500,
            soft_horizon=False,
            no_done_at_end=False,
        )
        .evaluation(
            evaluation_interval=10,
            evaluation_duration=100,
            evaluation_duration_unit='episodes',  # must be "episodes" or "timesteps"
            evaluation_num_workers=3,
            evaluation_parallel_to_training=True,
            # evaluation_config=,
            # custom_evaluation_function=,
        )
    )
    print(env_name)
    trainer = PPOTrainer(env=env_name, config=config)
    for epoch in range(500):
        result = trainer.train()
        # print(pretty_print(result))
        print(f"env: {env_name}, epoch: {epoch}")
        if epoch % 10 == 0:
            checkpoint = trainer.save()
            print("checkpoint saved at", checkpoint)
    return 0
# Launch one remote training task per environment and wait for all of them to finish.
distributed_trainer_refs = [distributed_trainer.remote(env_name) for env_name in env_names]
results = ray.get(distributed_trainer_refs)