Custom RLModule Observation Tensor Random Sorting

1. Severity of the issue: (select one)
None: I’m just curious or want clarification.
Low: Annoying but doesn’t hinder my work.
Medium: Significantly affects my productivity but can find a workaround.
High: Completely blocks me.

2. Environment:

  • Ray version: 2.46.0
  • Python version: 3.10
  • OS: Windows 10
  • Cloud/Infrastructure: /
  • Other libs/tools (if relevant): PettingZoo

3. What happened vs. what you expected:
I have a parallel PettingZoo environment that handles all my agents. The observations are stored in an OrderedDict whose keys are the agent IDs and whose values are the agents’ current observations. Each agent’s observation has this space:
gymnasium.spaces.Box(low=0.0, high=1.0, shape=(39,), dtype=np.float32)
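A minimal numpy-only sketch of this layout follows. It assumes integer agent IDs such as 1, 2, 10 and 15, which are the IDs that show up later in this post. The dict is stacked into one array in a fixed (sorted) order and the ID list is kept next to it, so row i always belongs to agent_ids[i]. That explicit row-to-agent association is what the GNN needs. The helper stack_observations is hypothetical and is not a PettingZoo or RLlib function.

```python
from collections import OrderedDict

import numpy as np

OBS_DIM = 39  # matches Box(low=0.0, high=1.0, shape=(39,), dtype=np.float32)


def stack_observations(obs_dict):
    """Stack a {agent_id: obs} dict into one array with a fixed row order.

    Rows are ordered by sorted agent ID, and the ID list is returned alongside,
    so row i of the array always belongs to agent_ids[i].
    """
    agent_ids = sorted(obs_dict.keys())
    stacked = np.stack([obs_dict[a] for a in agent_ids]).astype(np.float32)
    return agent_ids, stacked


# Hypothetical integer agent IDs, inserted in a non-sorted order on purpose.
rng = np.random.default_rng(0)
observations = OrderedDict(
    (agent_id, rng.random(OBS_DIM, dtype=np.float32)) for agent_id in [2, 10, 1, 15]
)

agent_ids, obs_matrix = stack_observations(observations)
print(agent_ids)         # [1, 2, 10, 15]
print(obs_matrix.shape)  # (4, 39)
```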
I am using a custom RLModule in which I train a Torch Geometric model (GNN). To train it correctly, I need to know which observation belongs to which node/agent, so that the rows match my edge_index. However, the batch that is passed to my _forward_train no longer contains any association with my original agent IDs. Its rows are also not in the same order as the original observations, and the order seems random from run to run. The batch has this structure:

(MultiAgentEnvRunner pid=64752)
{'obs': tensor([[ 2.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
                [10.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
                [ 1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
                [15.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,
                  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],

As a workaround, I add the agent ID as the first element of each observation (in the output above, 2, 10, 1 and 15 are all agent IDs). In the module I then sort the observations by this ID and remove the identifier again. However, I am unsure whether this is enough or whether I am breaking something internally. I suspect there is a better way to do this and that I am missing something.
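The environment-side half of this workaround could look roughly like the numpy sketch below. Assuming the IDs are prepended as described, two things follow: each observation grows from 39 to 40 features, and an ID such as 10 or 15 lies outside the original Box bounds (high=1.0), so the declared observation space would also need to change to describe the new observations. The names add_agent_id and MAX_AGENT_ID are hypothetical.

```python
import numpy as np

OBS_DIM = 39
MAX_AGENT_ID = 15  # hypothetical: largest integer agent ID in the environment


def add_agent_id(obs_dict):
    """Prepend each agent's integer ID to its observation (39 -> 40 features)."""
    return {
        agent_id: np.concatenate(([np.float32(agent_id)], obs)).astype(np.float32)
        for agent_id, obs in obs_dict.items()
    }


# Widened bounds for the ID-prefixed observation: the original Box used
# high=1.0 everywhere, which an ID such as 10 or 15 would violate.
low = np.zeros(OBS_DIM + 1, dtype=np.float32)
high = np.ones(OBS_DIM + 1, dtype=np.float32)
high[0] = MAX_AGENT_ID

obs = {2: np.zeros(OBS_DIM, dtype=np.float32), 15: np.ones(OBS_DIM, dtype=np.float32)}
tagged = add_agent_id(obs)
print(tagged[15].shape, tagged[15][0])  # (40,) 15.0
print(all(((o >= low) & (o <= high)).all() for o in tagged.values()))  # True
```

The widened arrays could then be passed as gymnasium.spaces.Box(low=low, high=high, dtype=np.float32).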
Here is my RLModule code:

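import torch
from torch_geometric.nn import GCNConv

from ray.rllib.core.columns import Columns
from ray.rllib.core.rl_module.apis.value_function_api import ValueFunctionAPI
from ray.rllib.core.rl_module.torch.torch_rl_module import TorchRLModule


class TrafficGNNModule(TorchRLModule, ValueFunctionAPI):
    def setup(self):
        super().setup()
        # Register the static graph as a buffer so it moves to the module's device
        edge_idx = torch.tensor(self.model_config["edge_index"], dtype=torch.long)
        self.register_buffer("edge_index", edge_idx, persistent=False)

        # Note: if the observation space includes the prepended agent ID, the GNN
        # only receives shape[0] - 1 features after that column is stripped.
        in_dim = self.observation_space.shape[0]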
        hidden_dim = self.model_config.get("hidden_dim", 64)
        num_actions = self.action_space.n

        # GNN layers
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)

        # Policy head
        self.action_head = torch.nn.Linear(hidden_dim, num_actions)
        # Value function head
        self.vf_head = torch.nn.Linear(hidden_dim, 1)

    def _sort_and_strip_agent_ids(self, obs_batch: torch.Tensor):
        """
        Sorts a batch of observations in ascending order by the first column
        (agent ID) and then removes this ID from the tensor.

        Args:
            obs_batch (torch.Tensor): Tensor of shape [B, F], where obs_batch[:, 0]
                contains the agent ID.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The sorted tensor without the
            agent-ID column, shape [B, F-1], and the sort index, shape [B].
        """
        # Sort order by agent ID
        sort_index = obs_batch[:, 0].long().argsort()
        # Sort the rows and remove the first column (agent ID)
        return obs_batch[sort_index, 1:], sort_index

    @staticmethod
    def _restore_order(out: torch.Tensor, sort_index: torch.Tensor) -> torch.Tensor:
        # Outputs must line up with the incoming batch rows, so undo the sort.
        # The argsort of a permutation is its inverse permutation.
        return out[sort_index.argsort()]

    def _gnn_embed(self, batch):
        # Shared GNN trunk for the policy and value heads
        x, sort_index = self._sort_and_strip_agent_ids(batch[Columns.OBS])
        h = torch.relu(self.conv1(x, self.edge_index))
        h = torch.relu(self.conv2(h, self.edge_index))
        return h, sort_index

    def _forward_train(self, batch, **kwargs):
        print(f"BATCH OBS: {batch}")
        h, sort_index = self._gnn_embed(batch)
        # Policy logits and value predictions, back in the incoming row order
        logits = self._restore_order(self.action_head(h), sort_index)
        vf_preds = self._restore_order(self.vf_head(h).squeeze(-1), sort_index)
        return {
            Columns.ACTION_DIST_INPUTS: logits,
            Columns.VF_PREDS: vf_preds,
        }

    def _forward_exploration(self, batch, **kwargs):
        with torch.no_grad():
            return self._forward_train(batch, **kwargs)

    def _forward_inference(self, batch, **kwargs):
        with torch.no_grad():
            return self._forward_train(batch, **kwargs)

    def compute_values(self, batch, embeddings=None) -> torch.Tensor:
        # ValueFunctionAPI: produce value predictions for advantage computation
        h, sort_index = self._gnn_embed(batch)
        return self._restore_order(self.vf_head(h).squeeze(-1), sort_index)
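Sorting changes which row holds which agent. As far as I understand RLlib, the outputs of _forward_train and compute_values are matched to the input batch rows by position, so per-row results computed in sorted order would need to be moved back before they are returned. Below is a minimal torch sketch of sorting by the ID column and then applying the inverse permutation. The helper names sort_by_agent_id and restore_order are hypothetical, and x * 10 merely stands in for the GNN and its heads.

```python
import torch


def sort_by_agent_id(obs_batch: torch.Tensor):
    """Sort rows by the agent ID in column 0, drop that column, return the permutation."""
    sort_index = obs_batch[:, 0].long().argsort()
    return obs_batch[sort_index, 1:], sort_index


def restore_order(out: torch.Tensor, sort_index: torch.Tensor) -> torch.Tensor:
    """Move rows computed in sorted order back to the batch's original row order."""
    restored = torch.empty_like(out)
    restored[sort_index] = out  # sorted row k came from original row sort_index[k]
    return restored


# Toy batch: column 0 is the agent ID, column 1 a single feature.
obs_batch = torch.tensor([[2.0, 0.2], [10.0, 1.0], [1.0, 0.1], [15.0, 1.5]])
x, sort_index = sort_by_agent_id(obs_batch)
per_row_output = x * 10  # stand-in for the GNN and its heads

restored = restore_order(per_row_output, sort_index)
# Row i of the result belongs to the same agent as row i of obs_batch again.
print(torch.equal(restored, obs_batch[:, 1:] * 10))  # True
```

Indexing with sort_index.argsort() (out[sort_index.argsort()]) yields the same inverse permutation as a gather instead of a scatter.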

Please note that I am a beginner, so please don’t be too harsh; further help or detailed explanations would be very much appreciated. Thank you! 🙂

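  • Expected: a batch tensor whose rows are in the same order as the original observations
  • Actual: a batch tensor whose rows are in a seemingly random order that changes between runs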
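Even a correctly sorted batch only lines up with a static edge_index under one condition. GCNConv treats row i of x as node i of edge_index, so the rows match the graph only if the batch contains exactly one row per node. During training, a module shared by all agents may receive many timesteps per agent in one batch (for example a shuffled minibatch), and then sorting alone does not recover the graph. A small torch check for this condition is below. It assumes the agent IDs have already been mapped to 0-based node indices that match edge_index, and check_one_row_per_node is a hypothetical helper.

```python
import torch


def check_one_row_per_node(node_ids: torch.Tensor, num_nodes: int) -> bool:
    """Return True if node_ids contains each index 0..num_nodes-1 exactly once.

    Only then does "row i of the sorted batch == node i of edge_index" hold.
    """
    ids = node_ids.long()
    if ids.numel() != num_nodes:
        return False
    return torch.equal(torch.sort(ids).values, torch.arange(num_nodes))


# One snapshot of 4 agents (hypothetical 0-based node indices): matches the graph.
print(check_one_row_per_node(torch.tensor([2.0, 0.0, 1.0, 3.0]), 4))  # True

# Two timesteps of the same 4 agents in one batch: sorting gives 8 rows for a
# 4-node graph, so the static edge_index no longer describes them.
two_steps = torch.tensor([2.0, 0.0, 1.0, 3.0, 0.0, 2.0, 3.0, 1.0])
print(check_one_row_per_node(two_steps, 4))  # False
```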