TF error when restoring from checkpoint, multi-agent

It turns out that agent.compute_action() expects a single agent's (non-dict) observation, but I was passing in the whole observation dict because that is what the environment returns. Here is the working code:

import ray.rllib.agents.a3c as a3c

# rebuild the trainer with the same config used for training,
# then restore the weights from the checkpoint
agent = a3c.A3CTrainer(config=config, env=AgentEnv)
agent.restore(args.checkpoint_path)

# instantiate the env class directly for the evaluation rollout
env = AgentEnv(env_config)

# run until the episode ends or the step limit is reached
done = False
episode_length = 50
length_count = 0
obs = env.reset()
while not done and (length_count <= episode_length):

  # compute_action() takes a single agent's observation, not the whole
  # multi-agent obs dict, so query each policy with its own entry
  action_dict = {}
  for policy_id in config["multiagent"]["policies"].keys():
    action_dict[policy_id] = agent.compute_action(obs[policy_id], policy_id=policy_id)

  # the env returns per-agent dicts; "__all__" signals the episode is over
  obs, reward, done_dict, info = env.step(action_dict)
  length_count += 1
  done = done_dict["__all__"]
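
For reference, here is a minimal sketch of the kind of multiagent config the loop above assumes (the agent names, spaces, and values are illustrative, not my actual setup): the policy IDs double as the agent IDs that AgentEnv uses as keys in its observation dict, which is what makes obs[policy_id] hand the right per-agent observation to compute_action().

from gym.spaces import Box, Discrete

# placeholder per-agent spaces -- replace with whatever AgentEnv actually uses
obs_space = Box(low=-1.0, high=1.0, shape=(4,))
act_space = Discrete(2)

config = {
  "env_config": env_config,
  "multiagent": {
    # policy_id -> (policy_cls, obs_space, act_space, per-policy config);
    # None tells RLlib to use the trainer's default policy class (A3C here)
    "policies": {
      "agent_0": (None, obs_space, act_space, {}),
      "agent_1": (None, obs_space, act_space, {}),
    },
    # each agent maps to the policy with the same name, so the rollout
    # loop can iterate over the policy IDs and index into obs directly
    "policy_mapping_fn": lambda agent_id: agent_id,
  },
}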