TF error when restoring from checkpoint, multi-agent

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

from ray.rllib.examples.env.random_env import RandomMultiAgentEnv



# Number of independent agents; one policy is created for each of them.
num_agents = 2

# Start from the PPO defaults and override only what this example needs.
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["env_config"] = {
    "num_agents": num_agents,
}

# Local env instance, used to read the spaces and for the manual rollout below.
# It is given its own copy of env_config so that a constructor which consumes
# keys from its argument cannot alter the config the trainer will receive.
# NOTE(review): RLlib's make_multi_agent wrapper appears to pop "num_agents"
# from the dict it is passed -- confirm for the installed Ray version.
env = RandomMultiAgentEnv(dict(config["env_config"]))

config["multiagent"] = {
    # policy id -> (policy_cls, obs_space, act_space, config)
    "policies": {
        str(agent_id): (None, env.observation_space, env.action_space, {})
        for agent_id in range(num_agents)
    },
    # Agent ids are mapped to the policy whose id is their string form.
    "policy_mapping_fn": lambda agent_id: str(agent_id),
}

ray.init()

# Everything between init and shutdown runs under try/finally so the Ray
# runtime is torn down even when training or the rollout raises.
try:
    trainer = ppo.PPOTrainer(config=config, env=RandomMultiAgentEnv)

    # Short training
    for _ in range(3):
        result = trainer.train()
        print(pretty_print(result))

    # Evaluate trained model with one manual episode on the local env.
    policy_for = config["multiagent"]["policy_mapping_fn"]
    done = False
    count = 0
    obs_dict = env.reset()
    while not done:
        # Act only for agents that actually have an observation, and resolve
        # each agent's policy through the same mapping the trainer uses
        # (the env's agent keys are integers, the policy ids are strings).
        action_dict = {
            agent_id: trainer.compute_action(
                agent_obs, policy_id=policy_for(agent_id)
            )
            for agent_id, agent_obs in obs_dict.items()
        }

        obs_dict, reward_dict, done_dict, info = env.step(action_dict)
        count += 1
        done = done_dict["__all__"]

    # The loop exits only once the episode is over: report its length and the
    # rewards / done flags returned by the final step.
    print(count)
    print(reward_dict)
    print(done_dict)
finally:
    ray.shutdown()

Cheers,

1 Like