.compute_actions() for a multi-agent environment

Hello all, I have a question. I am using the DQN algorithm on a custom multi-agent environment. After training, I want to test the performance of the model, so I need to call .compute_actions(obs, policy_id=DQNTFPolicy). Because the observation is a dictionary, I have to tell it which policy to use. But when I run the code, I receive this error:
‘tuple’ object has no attribute ‘items’
I don’t know where this error comes from. Here is my setup:

import os

import ray
from ray.rllib.algorithms.dqn import DQNConfig
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.tune.logger import pretty_print
from ray.tune.registry import register_env

ray.init(local_mode=True)

def env_creator(_):
    return AAM_Dispatch()

single_env = AAM_Dispatch()
env_name = "AAM_Dispatch"
register_env(env_name, env_creator)

obs_space = single_env.observation_space
act_space = single_env.action_space
num_agents = single_env.num_agents

def gen_policy():
    return DQNTFPolicy

policy_graphs = {
    "dqn_policy": (
        gen_policy(),
        obs_space,
        act_space,
        {},
    ),
}


def policy_mapping_fn(agent_id, episode, worker, **kwargs):
    return "dqn_policy"


dqn_config = (
    DQNConfig()
    .environment("AAM_Dispatch")
    .framework("tf")
    .rollouts(observation_filter="MeanStdFilter")
    .training(
        model={"vf_share_layers": True},
        n_step=3,
        gamma=0.99
    )
    .multi_agent(
        policies=policy_graphs,
        policy_mapping_fn=policy_mapping_fn,
        policies_to_train=["dqn_policy"],
    )
    .resources(num_gpus=int(os.environ.get("RLLIB_NUM_GPUS", "0")))
)

dqn = dqn_config.build()

for i in range(2):
    result_dqn = dqn.train()
    print(pretty_print(result_dqn))
    if result_dqn["episode_reward_mean"] > 48:
        print("Finish Training")
        quit(0)
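
The evaluation call that actually fails looks roughly like this (a sketch; I reset the env and pass the observations straight in):

# Sketch of the failing evaluation call after training.
obs = single_env.reset()
actions = dqn.compute_actions(obs, policy_id=DQNTFPolicy)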

I’ve done it a little differently than you, but if you try to use the model afterwards with compute_actions, it expects something like the input_dict that normally gets passed in. So I worked around the problem by wrapping the observation in a dictionary with a ‘default’ key. Hope this helps!
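
The gist is just this (a minimal sketch; algo here is whatever trained Algorithm you built):

obs, info = env.reset()
action = algo.compute_actions({'default': obs})['default']  # wrap the obs, unwrap the action

The full loop from my script: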


import time

import gymnasium as gym
import numpy as np
from PyFlyt.gym_envs import FlattenWaypointEnv

env = FlattenWaypointEnv(gym.make(id='PyFlyt/QuadX-Waypoints-v1', flight_mode=-1), context_length=1)

obs_list = []
obs, info = env.reset()
# env.env.env.env.env.drones[0].set_mode(-1)
targets = env.unwrapped.waypoints.targets
points = np.concatenate((obs[10:13].reshape(-1, 3), targets))
obs = {'default': obs}  # wrap the raw observation so compute_actions gets a dict
obs_list += [obs]

reward_list = []
action_list = []
start = time.time()
for i in range(10*40):
    compute_action = algo.compute_actions(obs)  # algo is the trained Algorithm from before
    action = compute_action['default']
    obs, reward, terminated, truncated, info = env.step(action)

    obs = {'default': obs}
    
    obs_list += [obs]
    
    reward_list += [reward]
    action_list += [action]
    
    if terminated or info['num_targets_reached'] == 4:
        break

arrays = [d['default'] for d in obs_list]
obs_array = np.vstack(arrays)
reward_array = np.array(reward_list)
action_array = np.array(action_list) 
env.close()
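
Applied to your multi-agent setup above: as far as I can tell, compute_actions iterates observations.items(), so it needs a dict mapping agent IDs to observations (not the tuple that a reset can return), and policy_id should be the string key from your policies dict rather than the policy class. Something like this (a sketch, assuming your env's reset() returns that per-agent dict):

obs = single_env.reset()  # expected: {agent_id: observation}
actions = dqn.compute_actions(obs, policy_id="dqn_policy")  # string ID, not DQNTFPolicy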