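It turns out that `agent.compute_action()` expects a single agent's (non-dict) observation, but I was passing in the whole observation dict because that is what the multi-agent environment returns. Indexing the dict per agent fixes it. Here is the working code: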
import ray.rllib.agents.a3c as a3c
agent = a3c.A3CTrainer(config=config, env=AgentEnv)
agent.restore(args.checkpoint_path)

# Instantiate the environment class directly so it can be stepped manually.
env = AgentEnv(env_config)

# Roll out one episode: stop when the env reports "__all__" done, or after
# `episode_length` steps (strict `<` so exactly `episode_length` steps max;
# the previous `<=` ran one extra step).
policies = config["multiagent"]["policies"]
episode_length = 50
length_count = 0
done = False
obs = env.reset()
while not done and length_count < episode_length:
    # compute_action() takes a single agent's observation, not the whole
    # multi-agent dict, so look up each agent's entry individually.
    # Iterating over `obs` (instead of every configured policy) avoids a
    # KeyError when the env omits an agent from this step's observations.
    # NOTE(review): like the original loop, this assumes each agent ID is
    # also its policy ID; use the config's policy mapping if that is not so.
    action_dict = {
        agent_id: agent.compute_action(agent_obs, policy_id=agent_id)
        for agent_id, agent_obs in obs.items()
        if agent_id in policies
    }
    obs, reward, done_dict, info = env.step(action_dict)
    length_count += 1
    done = done_dict["__all__"]