How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Hello,
I’ve recently started working with RLlib after moving from Stable Baselines 3 (SB3) because of RLlib’s strong multi-agent support and customizability. I’m implementing a custom multi-agent environment, and I’m running into issues when plugging in my own custom policy.
Environment Setup
My environment follows a multi-agent dict format, and I define my action and observation spaces as follows:
import numpy as np
from gymnasium.spaces import Box, Dict

single_agent_action_space = Box(low=-2, high=2, shape=(2,), dtype=np.float32)
single_agent_obs_space = Dict({
    "current_pos": Box(low=-100, high=100, shape=(2,), dtype=np.float32),
    "target": Box(low=-100, high=100, shape=(2,), dtype=np.float32),
})
# Multi-agent version
self.observation_space = Dict({
    str(agent_id): single_agent_obs_space
    for agent_id in self.possible_agents
})
self.action_space = Dict({
    str(agent_id): single_agent_action_space
    for agent_id in self.possible_agents
})
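For completeness, the rest of the environment follows the usual MultiAgentEnv reset/step conventions. Here is a condensed sketch (the internals are simplified placeholders, and _get_obs is just a stand-in for my actual observation code):

from ray.rllib.env.multi_agent_env import MultiAgentEnv

class MyMultiAgentEnv(MultiAgentEnv):
    def reset(self, *, seed=None, options=None):
        # one entry per agent id, matching the Dict spaces above
        obs = {str(agent_id): self._get_obs(agent_id) for agent_id in self.possible_agents}
        return obs, {}

    def step(self, action_dict):
        # action_dict, obs, rewards, etc. are all keyed by agent id
        obs = {agent_id: self._get_obs(agent_id) for agent_id in action_dict}
        rewards = {agent_id: 0.0 for agent_id in action_dict}
        terminateds = {"__all__": False}
        truncateds = {"__all__": False}
        return obs, rewards, terminateds, truncateds, {}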
Everything works fine when I use RLlib’s default PPO algorithm.
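For reference, the working baseline is just the stock PPO config pointed at the environment (the env class name here is a placeholder for my own):

from ray.rllib.algorithms.ppo import PPOConfig

# default PPO, no customization beyond the environment
config = PPOConfig().environment(MyMultiAgentEnv)
algo = config.build()
algo.train()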
Problem: Custom Policy & Expected Output Format
I’m implementing a custom policy, and I extract the observations like this:
obs_dict = batch[Columns.OBS]
After performing feature extraction, I pass the extracted features through my policy_net and value_net and return:
return {
    Columns.ACTION_DIST_INPUTS: action_dist_inputs,
    Columns.VF_PREDS: vf_preds,
}
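In condensed form, the whole module looks roughly like this (the networks are simplified placeholders for my own feature extraction and heads; the final packaging of the per-agent outputs is exactly the part I’m unsure about):

import torch
import torch.nn as nn

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.torch import TorchRLModule


class MyTorchRLModule(TorchRLModule):
    def setup(self):
        # placeholder heads; my real networks do more elaborate feature extraction
        self.policy_net = nn.Linear(4, 4)  # e.g. mean + log-std per action dim
        self.value_net = nn.Linear(4, 1)

    def _forward_exploration(self, batch, **kwargs):
        obs_dict = batch[Columns.OBS]
        actions_logits, value_preds = {}, {}
        for agent_id, agent_obs in obs_dict.items():
            # concatenate the per-agent observation components
            feats = torch.cat([agent_obs["current_pos"], agent_obs["target"]], dim=-1)
            actions_logits[agent_id] = self.policy_net(feats)
            value_preds[agent_id] = self.value_net(feats).squeeze(-1)
        # packaging these per-agent outputs is where things go wrong (see the
        # two variants below); here they are returned as plain dicts
        return {
            Columns.ACTION_DIST_INPUTS: actions_logits,
            Columns.VF_PREDS: value_preds,
        }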
I’ve tried two different approaches, but both lead to errors:
- Approach 1: returning a dictionary ({agent_id: tensor})
# Returns a dictionary with agent-specific action logits
action_dist_inputs = {agent_id: actions_logits[i] for i, agent_id in enumerate(obs_dict.keys())}
vf_preds = {agent_id: value_preds[i] for i, agent_id in enumerate(obs_dict.keys())}
Error:
AttributeError: 'dict' object has no attribute 'chunk'
- Approach 2: returning a stacked tensor (torch.stack(list(agent_dict.values())))
# Stack agent tensors into a single tensor
action_dist_inputs = torch.stack(list(actions_logits.values()))
vf_preds = torch.stack(list(value_preds.values()))
Error:
ValueError: The two structures don't have the same nested structure.
First structure: type=ndarray str=[-0.849055 0.6490481]
Second structure: type=dict str={'0': Box(-2.0, 2.0, (2,), float32), '65537': Box(-2.0, 2.0, (2,), float32)}
Question
What is RLlib’s expected output format for _forward() in a custom multi-agent policy? How should I properly structure the action distribution inputs for multiple agents? And how does RLlib handle multiple agents’ outputs inside _forward_exploration() and _forward_inference()?
Any advice would be greatly appreciated! @sven1977