TF error when restoring from checkpoint, multi-agent

Hello,

My goal is to evaluate the performance of a trained model from its checkpoints.
Using tune.run(), I trained multiple agents with my own custom environment class. A checkpoint is saved only at the end of each run.
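For context, the training call looks roughly like this (a trimmed sketch; the actual AgentEnv, config dict, and stop criteria come from my project and are not shown here):

from ray import tune

tune.run(
    "A3C",
    config=dict(config, env=AgentEnv),  # the same config dict I reuse for evaluation
    stop={"training_iteration": 4000},  # illustrative stop condition
    checkpoint_at_end=True,             # hence a checkpoint only at the end of each run
)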

I’m using the same config dict from the training code in the following eval script:

import ray
import ray.rllib.agents.a3c as a3c

ray.init()

agent = a3c.A3CTrainer(config=config, env=AgentEnv)
agent.restore(args.checkpoint_path)

# instantiate env class
env = AgentEnv(env_config)

# run until episode ends
episode_reward = 0
done = False
episode_length = 10
length_count = 0
obs = env.reset()
while not done and (length_count <= episode_length):
    action = agent.compute_action(obs, policy_id="antenna_1", explore=False)
    obs, reward, done, info = env.step(action)
    length_count += 1
    print(length_count, reward)

ray.shutdown()

I am getting the following error message. I've truncated the array values to keep this post short.

2021-03-29 08:02:00,955 INFO trainable.py:379 -- Current state after restoring: {'_iteration': 4000, '_timesteps_total': None, '_time_total': 13776.453395605087, '_episodes_total': 4806}
2021-03-29 08:02:01,009 ERROR tf_run_builder.py:47 -- Error fetching: [<tf.Tensor 'antenna_1/cond_1/Merge:0' shape=(?,) dtype=int64>, {'action_prob': <tf.Tensor 'antenna_1/Exp:0' shape=(?,) dtype=float32>, 'action_logp': <tf.Tensor 'antenna_1/cond_2/Merge:0' shape=(?,) dtype=float32>, 'action_dist_inputs': <tf.Tensor 'antenna_1/model_1/fc_out/BiasAdd:0' shape=(?, 5) dtype=float32>, 'vf_preds': <tf.Tensor 'antenna_1/Reshape_1:0' shape=(?,) dtype=float32>}], feed_dict={
  <tf.Tensor 'antenna_1/obs:0' shape=(?, 301, 9) dtype=float32>: [array({
    'antenna_0': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_1': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_2': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_3': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_4': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_5': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_6': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_7': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]]),
    'antenna_8': array([[0.        , 0.3125    , 0.375     , ..., 0.        , 0.        , 0.        ],
                        [0.        , 0.66666669, 0.58333331, ..., 0.25      , 0.16666667, 0.33333334]])}, dtype=object)],
  <tf.Tensor 'antenna_1/is_training:0' shape=() dtype=bool>: False,
  <tf.Tensor 'antenna_1/is_exploring:0' shape=() dtype=bool>: True,
  <tf.Tensor 'antenna_1/timestep:0' shape=() dtype=int64>: 0}
Traceback (most recent call last):
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/utils/tf_run_builder.py", line 44, in get
    self.feed_dict, os.environ.get("TF_TIMELINE_DIR"))
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/utils/tf_run_builder.py", line 89, in run_timeline
    fetches = sess.run(ops, feed_dict=feed_dict)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1149, in _run
    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
  File "/opt/conda/lib/python3.7/site-packages/numpy/core/_asarray.py", line 85, in asarray
    return array(a, dtype, copy=False, order=order)
TypeError: float() argument must be a string or a number, not 'dict'
Traceback (most recent call last):
  File "./eval_ma.py", line 125, in <module>
    action = agent.compute_action(obs, policy_id="antenna_1")
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/agents/trainer.py", line 857, in compute_action
    explore=explore)
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/policy/policy.py", line 219, in compute_single_action
    timestep=timestep)
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/policy/tf_policy.py", line 340, in compute_actions
    fetched = builder.get(to_fetch)
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/utils/tf_run_builder.py", line 48, in get
    raise e
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/utils/tf_run_builder.py", line 44, in get
    self.feed_dict, os.environ.get("TF_TIMELINE_DIR"))
  File "/opt/conda/lib/python3.7/site-packages/ray/rllib/utils/tf_run_builder.py", line 89, in run_timeline
    fetches = sess.run(ops, feed_dict=feed_dict)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 956, in run
    run_metadata_ptr)
  File "/opt/conda/lib/python3.7/site-packages/tensorflow_core/python/client/session.py", line 1149, in _run
    np_val = np.asarray(subfeed_val, dtype=subfeed_dtype)
  File "/opt/conda/lib/python3.7/site-packages/numpy/core/_asarray.py", line 85, in asarray
    return array(a, dtype, copy=False, order=order)
TypeError: float() argument must be a string or a number, not 'dict'

I am using ray v1.2.0. The policy network is the default one. I’m using A3C.

Edit: I’ve traced the code to this: ray/tf_run_builder.py at b87fc1be5505c577f01807fd342e0cdb2e129081 · ray-project/ray · GitHub

But I’m not sure how to debug further. Any help is appreciated!

@sven1977 could you take a look? Looks like we should have an action item here to provide a better error message if nothing else :slight_smile:

It turns out agent.compute_action() expects a non-dict observation (a single agent's observation), but I was passing in the dict version because env.step() expects a dict. Here is the working code:

import ray.rllib.agents.a3c as a3c

agent = a3c.A3CTrainer(config=config, env=AgentEnv)
agent.restore(args.checkpoint_path)

# instantiate env class
env = AgentEnv(env_config)

# run until episode ends
done = False
episode_length = 50
length_count = 0
obs = env.reset()
while not done and (length_count <= episode_length):
    # compute one action per policy, passing only that policy's own observation
    action_dict = {}
    for policy_id in config["multiagent"]["policies"].keys():
        action_dict[policy_id] = agent.compute_action(obs[policy_id], policy_id=policy_id)

    obs, reward_dict, done_dict, info = env.step(action_dict)
    length_count += 1
    done = done_dict["__all__"]
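
In hindsight the TypeError makes sense: my original call compute_action(obs, policy_id="antenna_1") fed the entire multi-agent observation dict into the single-policy placeholder antenna_1/obs:0 (shape (?, 301, 9)), and NumPy cannot cast a dict to float32, hence "float() argument must be a string or a number, not 'dict'".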

Hey @RickLan, thanks for posting this. Could you provide a self-sufficient reproduction script? Your script is missing some imports, the config dict, and the environment (you can use the RandomEnv in rllib/examples/env/random_env.py if you need to mock one).

import ray
import ray.rllib.agents.ppo as ppo
from ray.tune.logger import pretty_print

from ray.rllib.examples.env.random_env import RandomMultiAgentEnv



num_agents = 2
config = ppo.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["env_config"] = {
  "num_agents" : num_agents,
}
env = RandomMultiAgentEnv(config["env_config"])
config["multiagent"] = {
  "policies" : { # (policy_cls, obs_space, act_space, config)
    "{}".format(x): (None, env.observation_space, env.action_space, {}) for x in range(num_agents)
  },
  "policy_mapping_fn": lambda x: "{}".format(x),
}

ray.init()

trainer = ppo.PPOTrainer(config=config, env=RandomMultiAgentEnv)
# Short training
for i in range(3):
  result = trainer.train()
  print(pretty_print(result))

# Evaluate trained model
done = False
count = 0
obs_dict = env.reset()
while not done:
  action_dict = {}
  for policy in config["multiagent"]["policies"].keys():
    # RandomMultiAgentEnv's keys are integers, not strings.
    action_dict[int(policy)] = trainer.compute_action(obs_dict[int(policy)], policy_id=policy)

  obs_dict, reward_dict, done_dict, info = env.step(action_dict)
  count += 1
  done = done_dict["__all__"]
  if done:
    print(count)
    print(reward_dict)
    print(done_dict)


ray.shutdown()
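
If the repro should also exercise the restore-from-checkpoint path from my original question, a checkpoint round trip could be added right after the training loop, e.g. (a minimal sketch using the standard Trainer save()/restore() methods):

# hypothetical addition after the training loop above
checkpoint_path = trainer.save()   # write a checkpoint for the trained trainer
trainer.restore(checkpoint_path)   # restore it before running the evaluation loop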

Cheers,


Thanks for the script @RickLan! Let me try to reproduce now …

Hey Rick (@RickLan). I got the following (slightly modified) script working on the current master. Could you take a look? I'm doing the compute_action calls separately (one per agent), as this is required to specify the correct policy for each agent. Also, when calling env.step, I reassemble the newly computed actions into an action dict.

import ray
import ray.rllib.agents.a3c as a3c

from ray.rllib.examples.env.random_env import RandomMultiAgentEnv

num_agents = 2
config = a3c.DEFAULT_CONFIG.copy()
config["num_workers"] = 1
config["env_config"] = {
  "num_agents" : num_agents,
}
env = RandomMultiAgentEnv(config["env_config"])
config["multiagent"] = {
  "policies" : { # (policy_cls, obs_space, act_space, config)
    "{}".format(x): (None, env.observation_space, env.action_space, {}) for x in range(num_agents)
  },
  "policy_mapping_fn": lambda x: "{}".format(x),
}

ray.init()

agent = a3c.A3CTrainer(config=config, env=RandomMultiAgentEnv)

# run until episode ends
episode_reward = 0
done = False
episode_length = 10
length_count = 0
obs = env.reset()
while not done and (length_count <= episode_length):
    # NOTE: policy_id are strings ("0" or "1"), agent_id are the respective ints (0 or 1).
    # This is due to your "policies" dict using str keys (to identify the two policies used) in the "multiagent" setup.
    action1 = agent.compute_action(obs[0], policy_id="0", explore=False)
    action2 = agent.compute_action(obs[1], policy_id="1", explore=False)
    obs, reward, done, info = env.step({0: action1, 1: action2})
    length_count += 1
    print(length_count, reward)
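
(For a real evaluation from a checkpoint you would of course also call agent.restore(checkpoint_path) right after constructing the A3CTrainer, as in your original script; I skipped training and restoring here to keep the example minimal.)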


Hi @sven1977 ,

Thanks for getting back on this.

Looks good. As soon as I did that (TF error when restoring from checkpoint, multi-agent - #3 by RickLan), the TensorFlow errors went away. Sorry if I did not make that clear.

I was using v1.2.0. Good to know it still works on the master.
