Okay, somehow I've managed to work this out myself. The second approach looks like this:
import torch
from torch.distributions import Categorical

for step in range(1000):
    actions = {}
    # Build one action per agent from the current multi-agent observation dict.
    for agent_id, agent_obs in obs.items():
        agent_obs_vec = agent_obs['observations']      # the agent's actual observation
        agent_action_mask = agent_obs['action_mask']   # the agent's action mask
        # Add a batch dimension so the RLModule sees shape (1, obs_dim).
        obs_tensor = torch.tensor(agent_obs_vec, dtype=torch.float32).unsqueeze(0)
        action_mask_tensor = torch.tensor(agent_action_mask, dtype=torch.float32)
        with torch.no_grad():
            action_out = rl_module.forward_inference(
                {"obs": {'observations': obs_tensor,
                         'action_mask': action_mask_tensor}}
            )
        # Sample an action from the logits returned by the module.
        logits = action_out["action_dist_inputs"]
        dist = Categorical(logits=logits)
        action = dist.sample().item()
        actions[agent_id] = action

    obs, rewards, terminateds, truncateds, infos = env.step(actions)
    all_rewards.append(rewards)
    print(f"[STEP {step}] Rewards: {rewards}, Actions: {actions}")
    sum_rewards += rewards['gruz_1']
    if terminateds.get("__all__", True) or truncateds.get("__all__", True):
        # print(all_rewards)
        print(f'Sum_rewards={sum_rewards}')
        break
and it actually works for me.
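For reference, the loop assumes that env, rl_module, obs, all_rewards, and sum_rewards already exist. A minimal setup sketch (how you build the env and load the trained RLModule is specific to your project, so those two are only placeholders here):

# Sketch of the setup the loop above relies on.
# env       = <your MultiAgentEnv instance>
# rl_module = <the trained RLModule, obtained from your Algorithm or its checkpoint>
all_rewards = []
sum_rewards = 0.0
obs, infos = env.reset()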
PS: I'm still not sure whether it works correctly. "Not crashing with errors" doesn't always mean everything is perfect, tbh.
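One cheap sanity check I'd consider, assuming the mask convention is 1 = allowed and 0 = forbidden (that's an assumption about your env): right after sampling inside the per-agent loop, verify the sampled action is actually allowed, which tells you whether the module really respects the mask.

# Sketch: place this right after `action = dist.sample().item()` in the loop above.
# Assumes mask convention: 1 = allowed, 0 = forbidden.
if agent_action_mask[action] == 0:
    print(f"[WARN] {agent_id} sampled a masked-out action {action} at step {step}")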