How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Train centralized critic PPO and PPO at the same time
First of all, my Ray version is 0.8.6, which is probably out of date. I have two types of agents in my environment, and I need to train them in the same environment at the same time; they have different observation and action spaces. Previously I trained both of them with PPO. Now I want to train one of them with a centralized-critic PPO and the other with plain PPO.
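For context, here is a rough sketch of the two kinds of spaces. The bounds and the widths 9 and 27 are only illustrative (inferred from the batch shapes printed later in this post); my actual spaces are defined by the Flow environment:

import numpy as np
from gym.spaces import Box

# Hypothetical spaces, for illustration only.
obs_space_av = Box(low=-float("inf"), high=float("inf"), shape=(9,), dtype=np.float32)   # 'cav' agents
act_space_av = Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)
obs_space_tl = Box(low=-float("inf"), high=float("inf"), shape=(27,), dtype=np.float32)  # 'tl' agents
act_space_tl = Box(low=-1.0, high=1.0, shape=(1,), dtype=np.float32)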
I tried the following approach. First, I wrote a centralized-critic PPO modeled on the centralized_critic.py example; my code is as follows:
import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork
from ray.rllib.utils.annotations import override
from ray.rllib.evaluation.postprocessing import compute_advantages, Postprocessing
from ray.rllib.policy.sample_batch import SampleBatch
import numpy as np
from ray.rllib.agents.ppo.ppo_tf_policy import ppo_surrogate_loss as tf_loss
from ray.rllib.agents.ppo.ppo_tf_policy import PPOTFPolicy, KLCoeffMixin
from ray.rllib.policy.tf_policy import LearningRateSchedule, EntropyCoeffSchedule
from ray.rllib.utils.tf_ops import make_tf_callable
from ray.rllib.utils.explained_variance import explained_variance
from ray.rllib.utils.framework import try_import_tf
tf1, tf, tfv = try_import_tf()
class CentralizedCriticModel(TFModelV2):
    """Multi-agent model that implements a centralized value function."""

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super(CentralizedCriticModel, self).__init__(obs_space, action_space, num_outputs, model_config, name)
        # Base of the model
        self.model = FullyConnectedNetwork(obs_space, action_space, num_outputs, model_config, name)
        self.register_variables(self.model.variables())
        # Central VF maps (obs, opp_obs) -> vf_pred
        obs = tf.keras.layers.Input(shape=(obs_space.shape[0], ), name="obs")
        other_obs = tf.keras.layers.Input(shape=(obs_space.shape[0] * 1, ), name="other_obs")
        concat_obs = tf.keras.layers.Concatenate(axis=1)([obs, other_obs])
        central_vf_dense = tf.keras.layers.Dense(190, activation=tf.nn.tanh, name="c_vf_dense")(concat_obs)
        central_vf_out = tf.keras.layers.Dense(1, activation=None, name="c_vf_out")(central_vf_dense)
        self.central_vf = tf.keras.Model(inputs=[obs, other_obs], outputs=central_vf_out)
        self.register_variables(self.central_vf.variables)

    @override(TFModelV2)
    def forward(self, input_dict, state, seq_lens):
        return self.model.forward(input_dict, state, seq_lens)

    def central_value_function(self, obs, other_obs):
        return tf.reshape(self.central_vf([obs, other_obs]), [-1])

    @override(TFModelV2)
    def value_function(self):
        return self.model.value_function()
def centralized_critic_postprocessing(policy, sample_batch, other_agent_batches=None, episode=None):
    if other_agent_batches:
        # other_agent_batches maps agent ids to (policy, SampleBatch) tuples,
        # so batch[1] below is the neighboring agent's sample batch.
        neighbor_batches = list(other_agent_batches.values())
        other_obs = np.concatenate([batch[1][SampleBatch.CUR_OBS] for batch in neighbor_batches], axis=1)
        for i, (agent_id, neighbor) in enumerate(neighbor_batches):
            print(f"Neighbor batch index {i}, agent ID {agent_id} dimensions: {neighbor[SampleBatch.CUR_OBS].shape}")
        sample_batch["neighbor_obs"] = other_obs
        sample_batch[SampleBatch.VF_PREDS] = policy.compute_central_vf(sample_batch[SampleBatch.CUR_OBS], other_obs)
    else:
        sample_batch["neighbor_obs"] = np.zeros_like(sample_batch[SampleBatch.CUR_OBS])
        sample_batch[SampleBatch.VF_PREDS] = np.zeros_like(sample_batch[SampleBatch.REWARDS])
    last_r = sample_batch[SampleBatch.VF_PREDS][-1] if not sample_batch["dones"][-1] else 0.0
    return compute_advantages(sample_batch, last_r, policy.config["gamma"], policy.config["lambda"], use_gae=policy.config["use_gae"])
def loss_with_central_critic(policy, model, dist_class, train_batch):
    CentralizedValueMixin.__init__(policy)
    vf_saved = model.value_function
    model.value_function = lambda: policy.model.central_value_function(train_batch[SampleBatch.CUR_OBS], train_batch["neighbor_obs"])
    policy._central_value_out = model.value_function()
    loss = tf_loss(policy, model, dist_class, train_batch)
    model.value_function = vf_saved
    return loss


class CentralizedValueMixin:
    def __init__(self):
        self.compute_central_vf = make_tf_callable(self.get_session())(self.model.central_value_function)


def setup_tf_mixins(policy, obs_space, action_space, config):
    KLCoeffMixin.__init__(policy, config)
    EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"], config["entropy_coeff_schedule"])
    LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"])


def central_vf_stats(policy, train_batch, grads):
    return {
        "vf_explained_var": explained_variance(
            train_batch[Postprocessing.VALUE_TARGETS],
            policy._central_value_out),
    }


CCPPOTFPolicy = PPOTFPolicy.with_updates(
    name="CCPPOTFPolicy",
    postprocess_fn=centralized_critic_postprocessing,
    loss_fn=loss_with_central_critic,
    before_loss_init=setup_tf_mixins,
    grad_stats_fn=central_vf_stats,
    mixins=[LearningRateSchedule, EntropyCoeffSchedule, KLCoeffMixin, CentralizedValueMixin])
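As a standalone sanity check, the central value head behaves as expected when both inputs have the same number of rows. This is a plain Keras sketch with a hypothetical observation width of 27 (matching the (200, 27) batch shape printed further below), independent of RLlib:

import numpy as np
import tensorflow as tf

obs_dim = 27  # width taken from the (200, 27) neighbor batch printed in the logs below
batch = 5

obs_in = tf.keras.layers.Input(shape=(obs_dim,), name="obs")
other_in = tf.keras.layers.Input(shape=(obs_dim,), name="other_obs")
concat = tf.keras.layers.Concatenate(axis=1)([obs_in, other_in])
dense = tf.keras.layers.Dense(190, activation=tf.nn.tanh)(concat)
out = tf.keras.layers.Dense(1, activation=None)(dense)
central_vf = tf.keras.Model(inputs=[obs_in, other_in], outputs=out)

# Both inputs must share the batch (row) dimension; only the feature
# dimension is concatenated.
vf = central_vf.predict([np.zeros((batch, obs_dim), np.float32),
                         np.zeros((batch, obs_dim), np.float32)])
print(vf.shape)  # (5, 1); central_value_function reshapes this to (-1,)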
Then the two policies are mapped to different agents in the policy mapping:
POLICY_GRAPHS = {
    'cav': (PPOTFPolicy, obs_space_av, act_space_av, {}),
    'tl': (CCPPOTFPolicy, obs_space_tl, act_space_tl, {"model": {"custom_model": "cc_model"}}),
}


def policy_mapping_fn(agent_id):
    if agent_id.startswith("center"):
        return "tl"
    else:
        return "cav"


policies_to_train = ["cav", "tl"]
Then I use the following training script to train:
import argparse
import json
import sys
from copy import deepcopy
from flow.utils.rllib import FlowParamsEncoder
from flow.utils.registry import make_create_env
from ray.rllib.models import ModelCatalog
from exp_configs.rl.multiagent.centralized_critic_model import CentralizedCriticModel
def parse_args(args):
    """Parse training options user can specify in command line.

    Returns
    -------
    argparse.Namespace
        the output parser object
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description="Parse argument used when running a Flow simulation.",
        epilog="python train_ppo.py EXP_CONFIG")

    # required input parameters
    parser.add_argument(
        'exp_config', type=str,
        help='Name of the experiment configuration file, as located in '
             'exp_configs/rl/singleagent or exp_configs/rl/multiagent.')

    # optional input parameters
    parser.add_argument(
        '--rl_trainer', type=str, default="rllib",
        help='the RL trainer to use. either rllib or Stable-Baselines')
    parser.add_argument(
        '--num_cpus', type=int, default=1,
        help='How many CPUs to use')
    parser.add_argument(
        '--num_steps', type=int, default=999,
        help='How many total steps to perform learning over')
    parser.add_argument(
        '--rollout_size', type=int, default=1000,
        help='How many steps are in a training batch.')
    parser.add_argument(
        '--checkpoint_path', type=str, default=None,
        help='Directory with checkpoint to restore training from.')
    return parser.parse_known_args(args)[0]


def setup_exps_rllib(flow_params,
                     n_cpus,
                     n_rollouts,
                     policy_graphs=None,
                     policy_mapping_fn=None,
                     policies_to_train=None):
    """Return the relevant components of an RLlib experiment.

    Parameters
    ----------
    flow_params : dict
        flow-specific parameters (see flow/utils/registry.py)
    n_cpus : int
        number of CPUs to run the experiment over
    n_rollouts : int
        number of rollouts per training iteration
    policy_graphs : dict, optional
    policy_mapping_fn : function, optional
    policies_to_train : list of str, optional
        set in module in exp_configs/rl/multiagent

    Returns
    -------
    str
        name of the training algorithm
    str
        name of the gym environment to be trained
    dict
        training configuration parameters
    """
    from ray import tune
    from ray.tune.registry import register_env
    try:
        from ray.rllib.agents.agent import get_agent_class
    except ImportError:
        from ray.rllib.agents.registry import get_agent_class

    horizon = flow_params['env'].horizon

    alg_run = "PPO"
    agent_cls = get_agent_class(alg_run)
    config = deepcopy(agent_cls._default_config)

    config["num_workers"] = n_cpus
    config["train_batch_size"] = horizon * n_rollouts
    config["gamma"] = 0.999  # discount rate
    config["model"].update({"fcnet_hiddens": [32, 32, 32]})
    config["use_gae"] = True
    config["lambda"] = 0.97
    config["kl_target"] = 0.02
    config["num_sgd_iter"] = 10
    config["horizon"] = horizon
    config["num_gpus"] = 3
    config["timesteps_per_iteration"] = horizon * n_rollouts
    config['no_done_at_end'] = True
    config['log_level'] = "ERROR"

    # save the flow params for replay, params.json file
    flow_json = json.dumps(
        flow_params, cls=FlowParamsEncoder, sort_keys=True, indent=4)
    config['env_config']['flow_params'] = flow_json
    config['env_config']['run'] = alg_run

    # multiagent configuration
    if policy_graphs is not None:
        print("policy_graphs", policy_graphs)
        config['multiagent'].update({'policies': policy_graphs})
    if policy_mapping_fn is not None:
        config['multiagent'].update(
            {'policy_mapping_fn': tune.function(policy_mapping_fn)})
    if policies_to_train is not None:
        config['multiagent'].update({'policies_to_train': policies_to_train})

    create_env, gym_name = make_create_env(params=flow_params)

    # Register as rllib env
    register_env(gym_name, create_env)
    return alg_run, gym_name, config


def train_rllib(submodule, flags):
    """Train policies using the PPO algorithm in RLlib."""
    import ray
    from ray.tune import run_experiments

    flow_params = submodule.flow_params
    n_cpus = submodule.N_CPUS
    n_rollouts = submodule.N_ROLLOUTS
    policy_graphs = getattr(submodule, "POLICY_GRAPHS", None)
    policy_mapping_fn = getattr(submodule, "policy_mapping_fn", None)
    policies_to_train = getattr(submodule, "policies_to_train", None)

    alg_run, gym_name, config = setup_exps_rllib(
        flow_params, n_cpus, n_rollouts,
        policy_graphs, policy_mapping_fn, policies_to_train)

    ray.init(num_cpus=n_cpus + 1)  # , object_store_memory=200 * 1024 * 1024
    exp_config = {
        "run": alg_run,
        "env": gym_name,
        "config": {
            **config
        },
        "checkpoint_freq": 5,
        "checkpoint_at_end": True,
        "max_failures": 999,
        "stop": {
            "training_iteration": flags.num_steps,
        },
    }

    if flags.checkpoint_path is not None:
        exp_config['restore'] = flags.checkpoint_path
    run_experiments({flow_params["exp_tag"]: exp_config})


def main(args):
    """Perform the training operations."""
    # Parse script-level arguments (not including package arguments).
    ModelCatalog.register_custom_model("cc_model", CentralizedCriticModel)
    flags = parse_args(args)

    # Import relevant information from the exp_config script.
    module = __import__(
        "exp_configs.rl.singleagent", fromlist=[flags.exp_config])
    module_ma = __import__(
        "exp_configs.rl.multiagent", fromlist=[flags.exp_config])

    # Import the submodule containing the specified exp_config and determine
    # whether the environment is single agent or multiagent.
    if hasattr(module, flags.exp_config):
        submodule = getattr(module, flags.exp_config)
    elif hasattr(module_ma, flags.exp_config):
        submodule = getattr(module_ma, flags.exp_config)
        assert flags.rl_trainer.lower() in ["rllib", "h-baselines"], \
            "Currently, multiagent experiments are only supported through "\
            "RLlib. Try running this experiment using RLlib: " \
            "'python train_ppo.py EXP_CONFIG'"
    else:
        raise ValueError("Unable to find experiment config.")

    # Perform the training operation.
    train_rllib(submodule, flags)


if __name__ == "__main__":
    main(sys.argv[1:])
When I try to train, the following problem arises:
/flow/examples/exp_configs/rl/multiagent/centralized_critic_model.py", line 53, in centralized_critic_postprocessing
other_obs = np.concatenate([batch[1][SampleBatch.CUR_OBS] for batch in neighbor_batches], axis=-1)
File "<__array_function__ internals>", line 6, in concatenate
ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 200 and the array at index 1 has size 94
By printing from inside the centralized_critic_postprocessing function, I see the following:
(pid=18217) Neighbor batch index 0, agent ID <ray.rllib.policy.tf_policy_template.CCPPOTFPolicy object at 0x7f60e6676748> dimensions: (200, 27)
(pid=18217) Neighbor batch index 1, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (144, 9)
(pid=18217) Neighbor batch index 2, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (76, 9)
(pid=18217) Neighbor batch index 3, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (40, 9)
(pid=18217) Neighbor batch index 4, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (31, 9)
(pid=18217) Neighbor batch index 5, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (64, 9)
(pid=18217) Neighbor batch index 6, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (84, 9)
(pid=18217) Neighbor batch index 7, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (115, 9)
(pid=18217) Neighbor batch index 8, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (13, 9)
(pid=18217) Neighbor batch index 9, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (33, 9)
In my opinion, this happens because batches from the agents that use PPOTFPolicy also end up in centralized_critic_postprocessing (via other_agent_batches), which causes the dimension mismatch. I think the plain-PPO agents' batches should not be mixed into the CC policy's postprocessing, and I don't know why this happens.
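The failure itself is easy to reproduce in isolation: np.concatenate along axis 1 requires every array to have the same number of rows, and the neighbor batches above have different lengths (and widths):

import numpy as np

a = np.zeros((200, 27))  # shape of the CC agent's own batch (neighbor index 0)
b = np.zeros((144, 9))   # shape of a neighboring PPO agent's batch
try:
    np.concatenate([a, b], axis=1)
except ValueError as e:
    print(e)  # row counts (200 vs. 144) must match for concatenation along axis 1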
I then tried something else: copying two_trainer_workflow.py, I defined a new trainer, shown below.
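The snippet leaves out its imports; as far as I understand, they are roughly the ones used in the two_trainer_workflow.py example, though the exact module paths may differ between Ray versions:

# Imports as I understand them from the two_trainer_workflow.py example;
# module paths may differ between Ray releases.
from ray.rllib.agents.ppo.ppo import DEFAULT_CONFIG as PPO_CONFIG
from ray.rllib.agents.trainer_template import build_trainer
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.common import _get_shared_metrics
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.rollout_ops import (ConcatBatches, ParallelRollouts,
                                             SelectExperiences,
                                             StandardizeFields)
from ray.rllib.execution.train_ops import TrainOneStep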
def custom_training_workflow(workers: WorkerSet, config: dict):
    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=1000,
        buffer_size=50000,
        replay_batch_size=64)

    def add_cc_metrics(batch):
        print("CC policy learning on samples from",
              batch.policy_batches.keys(), "env steps", batch.count,
              "agent steps", batch.total())
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_CC"] += batch.total()
        return batch

    def add_ppo_metrics(batch):
        print("PPO policy learning on samples from",
              batch.policy_batches.keys(), "env steps", batch.count,
              "agent steps", batch.total())
        metrics = _get_shared_metrics()
        metrics.counters["agent_steps_trained_PPO"] += batch.total()
        return batch

    # Generate common experiences.
    rollouts = ParallelRollouts(workers, mode="bulk_sync")
    r1, r2 = rollouts.duplicate(n=2)

    # CC sub-flow.
    cc_train_op = r1.for_each(SelectExperiences(["cc_policy"])) \
        .combine(ConcatBatches(min_batch_size=200)) \
        .for_each(add_cc_metrics) \
        .for_each(StandardizeFields(["advantages"])) \
        .for_each(TrainOneStep(workers, policies=["cc_policy"], num_sgd_iter=10, sgd_minibatch_size=128))

    # PPO sub-flow.
    ppo_train_op = r2.for_each(SelectExperiences(["ppo_policy"])) \
        .combine(ConcatBatches(min_batch_size=200)) \
        .for_each(add_ppo_metrics) \
        .for_each(StandardizeFields(["advantages"])) \
        .for_each(TrainOneStep(workers, policies=["ppo_policy"], num_sgd_iter=10, sgd_minibatch_size=128))

    # Combined training flow
    train_op = Concurrently(
        [cc_train_op, ppo_train_op], mode="async", output_indexes=[1])

    return StandardMetricsReporting(train_op, workers, config)


CustomTrainer = build_trainer(
    name="PPO_CC_MultiAgent",
    default_config=PPO_CONFIG,
    default_policy=None,
    execution_plan=custom_training_workflow)
The new policies are mapped as follows:
POLICY_GRAPHS = {
    "ppo_policy": (PPOTFPolicy, obs_space_av, act_space_av, {}),
    "cc_policy": (CCPPOTFPolicy, obs_space_tl, act_space_tl, {"model": {"custom_model": "cc_model"}}),
}


def policy_mapping_fn(agent_id):
    # Map a policy in RLlib.
    if agent_id.startswith("center"):
        return "cc_policy"
    else:
        return "ppo_policy"


policies_to_train = ["ppo_policy", "cc_policy"]
Then I changed the configuration:
config = deepcopy(PPO_CONFIG)
config["num_workers"] = n_cpus
config["train_batch_size"] = horizon * n_rollouts
config["gamma"] = 0.999  # discount rate
# config["model"].update({"custom_model": "cc_model"})
config["model"].update({"fcnet_hiddens": [32, 32, 32]})
config["use_gae"] = True
config["lambda"] = 0.97
config["kl_target"] = 0.02
config["num_sgd_iter"] = 10
config["horizon"] = horizon
config["num_gpus"] = 1
config["timesteps_per_iteration"] = horizon * n_rollouts
config['no_done_at_end'] = True
config['log_level'] = "ERROR"

# save the flow params for replay, params.json file
flow_json = json.dumps(
    flow_params, cls=FlowParamsEncoder, sort_keys=True, indent=4)
config['env_config']['flow_params'] = flow_json
# config['env_config']['run'] = alg_run

# multiagent configuration
if policy_graphs is not None:
    print("policy_graphs", policy_graphs)
    config['multiagent'].update({'policies': policy_graphs})
if policy_mapping_fn is not None:
    config['multiagent'].update(
        {'policy_mapping_fn': tune.function(policy_mapping_fn)})
if policies_to_train is not None:
    config['multiagent'].update({'policies_to_train': policies_to_train})

create_env, gym_name = make_create_env(params=flow_params)

# Register as rllib env
register_env(gym_name, create_env)
return gym_name, config
exp_config = {
    "run": CustomTrainer,
    "env": gym_name,
    "config": {
        **config
    },
    "checkpoint_freq": 5,
    "checkpoint_at_end": True,
    "max_failures": 999,
    "stop": {
        "training_iteration": flags.num_steps,
    },
}

if flags.checkpoint_path is not None:
    exp_config['restore'] = flags.checkpoint_path
run_experiments({flow_params["exp_tag"]: exp_config})
This second attempt produced the same error as the first one:
ValueError: all the input array dimensions for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 200 and the array at index 1 has size 94
The print message is still:
(pid=18217) Neighbor batch index 0, agent ID <ray.rllib.policy.tf_policy_template.CCPPOTFPolicy object at 0x7f60e6676748> dimensions: (200, 27)
(pid=18217) Neighbor batch index 1, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (144, 9)
(pid=18217) Neighbor batch index 2, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (76, 9)
(pid=18217) Neighbor batch index 3, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (40, 9)
(pid=18217) Neighbor batch index 4, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (31, 9)
(pid=18217) Neighbor batch index 5, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (64, 9)
(pid=18217) Neighbor batch index 6, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (84, 9)
(pid=18217) Neighbor batch index 7, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (115, 9)
(pid=18217) Neighbor batch index 8, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (13, 9)
(pid=18217) Neighbor batch index 9, agent ID <ray.rllib.policy.tf_policy_template.PPOTFPolicy object at 0x7f60d4673550> dimensions: (33, 9)
I'm new to RLlib and want to understand what went wrong.