Passing additional action information from custom_model to environment

Hi everyone,

I am trying to pass additional information from custom_model to my environment. Here is my CustomModel:

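from typing import Dict, List, Tuple

import gymnasium as gym  # older Ray versions: import gym
import torch
from torch import nn
from torch_geometric.nn import SAGEConv, TopKPooling
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils.typing import ModelConfigDict, TensorType

# schedule_to_graph (not shown) builds a torch_geometric Batch from the observations.

class GNNPPO(TorchModelV2, nn.Module):
    def __init__(self, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)

        ...  # elided: sets in_channels, hidden_channels, num_gnn_layers, self.top_k, self.dropout, self.norm_processing_time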

        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_gnn_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))

        self.pool = TopKPooling(in_channels=hidden_channels, ratio=self.top_k)
        self.fc1 = torch.nn.Linear(self.top_k*hidden_channels, hidden_channels)

        self.action_head_top_k_nodes = torch.nn.Linear(hidden_channels, self.top_k)
        self.action_head_shift_op = torch.nn.Linear(hidden_channels,
                                                    model_config["custom_model_config"]["num_shift_operations"])
        self._critic_head = nn.Sequential(nn.Linear(hidden_channels, 1))

    def forward(
            self,
            input_dict: Dict[str, TensorType],
            state: List[TensorType],
            seq_lens: TensorType,
    ) -> Tuple[TensorType, List[TensorType]]:
        schedules = input_dict["obs"].values
        graph_batch = schedule_to_graph(schedules, norm_processing_time=self.norm_processing_time)
        x, edge_index, batch = graph_batch.x, graph_batch.edge_index, graph_batch.batch

        for conv in self.convs:
            x = conv(x, edge_index)
            x = nn.functional.relu(x) 
            x = nn.functional.dropout(x, p=self.dropout, training=self.training) 

        # Now use TopKPooling to get the top k nodes
        x, edge_index, _, batch, perm, score = self.pool(x, edge_index, None, batch, None)
        self.top_k_nodes = perm
        x = x.view(-1, self.top_k * x.size(1))
        x = self.fc1(x)
        x = nn.functional.relu(x)
        x = nn.functional.dropout(x, p=self.dropout, training=self.training)

        logits = [self.action_head_top_k_nodes(x), self.action_head_shift_op(x)]
        outs = torch.cat(logits, dim=1)
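        self._value = self._critic_head(x)
        return outs, state

    def value_function(self) -> TensorType:
        # PPO's value loss calls this after forward(); return one value per batch row
        return self._value.reshape(-1)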
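To see how PPO would consume outs, here is a plain-Python sketch. It assumes the action space is MultiDiscrete([top_k, num_shift_operations]) (the post does not show it), in which case RLlib's multi-categorical distribution cuts each row of outs into a top_k-wide chunk and a num_shift_operations-wide chunk and samples one index from each. split_logits and sample_multi_categorical are hypothetical helpers that mimic that behavior; they are not RLlib functions.

```python
import math
import random
from typing import List, Sequence


def split_logits(row: Sequence[float], input_lens: Sequence[int]) -> List[List[float]]:
    """Cut one row of concatenated logits into one chunk per action head."""
    chunks, start = [], 0
    for n in input_lens:
        chunks.append(list(row[start:start + n]))
        start += n
    if start != len(row):
        raise ValueError("input_lens must cover the whole row")
    return chunks


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_multi_categorical(row: Sequence[float], input_lens: Sequence[int],
                             rng: random.Random) -> List[int]:
    """Sample one index per head, like a MultiDiscrete action."""
    action = []
    for chunk in split_logits(row, input_lens):
        probs = softmax(chunk)
        action.append(rng.choices(range(len(probs)), weights=probs)[0])
    return action


# top_k = 5 slot logits followed by 3 shift-operation logits
row = [0, 0, 50, 0, 0, 0, 50, 0]
print(sample_multi_categorical(row, [5, 3], random.Random(0)))
```

The first entry of such an action is a position within the pooled nodes (a slot in perm), not a node id, which is exactly why the environment cannot interpret it on its own.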

I use a GNN. In this GNN, there is a so-called TopKPooling layer, which passes the “most important” nodes on to the next layers. The layer self.action_head_top_k_nodes should produce logits corresponding to the importance of the selected nodes. But in my environment, I need to know which nodes these are. E.g., the TopKPooling layer selected the nodes [5, 14, 11, 8, 7], and the index sampled from the softmax over the logits of self.action_head_top_k_nodes is 2. Then my environment needs to know that node [5, 14, 11, 8, 7][2] = 11 was selected. How can I pass this information to the environment?
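The lookup itself is simple, but with batched graphs there is a catch: perm holds indices into the batched node tensor, not per-graph node ids. The sketch below assumes perm contains k entries per graph, grouped graph by graph (the same assumption x.view(-1, self.top_k * x.size(1)) already relies on), and that ptr[g] is the index of graph g's first node, which a torch_geometric Batch exposes as graph_batch.ptr. slots_to_node_ids is a hypothetical helper written in plain Python for clarity.

```python
from typing import List, Sequence


def slots_to_node_ids(perm: Sequence[int], ptr: Sequence[int], k: int,
                      slots: Sequence[int]) -> List[int]:
    """Translate each graph's sampled top-k slot into that graph's own node id.

    Assumes `perm` holds k global node indices per graph, grouped graph by graph,
    and `ptr[g]` is the index of graph g's first node in the batched node tensor.
    """
    node_ids = []
    for g, slot in enumerate(slots):
        if not 0 <= slot < k:
            raise ValueError(f"slot {slot} out of range for k={k}")
        global_idx = perm[g * k + slot]       # node chosen in the batched graph
        node_ids.append(global_idx - ptr[g])  # shift back to graph-local numbering
    return node_ids


# Example from the question: one graph, TopKPooling kept [5, 14, 11, 8, 7], slot 2 sampled.
print(slots_to_node_ids([5, 14, 11, 8, 7], ptr=[0], k=5, slots=[2]))  # [11]
```

With tensors, the same lookup is roughly perm.view(-1, k).gather(1, slots.unsqueeze(1)).squeeze(1) - graph_batch.ptr[:-1]. It has to use the perm from the same forward pass that produced the logits, since self.top_k_nodes is overwritten on every call to forward.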

Please, I need your help!
Thanks in advance

I don’t know enough to give you a confident answer yet. However, it seems like you want to pass additional information from your model to your environment. This is not a typical use case, as the model usually receives information from the environment and not the other way around.
In your case, you want to pass the selected node information from your model to your environment. One way to do this might be to include this information in the action that your model outputs. However, this would require modifying your environment to accept and process this additional information, which might not be ideal.
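On the environment side, "include this information in the action" could look like the sketch below. The action layout [slot, shift_op, cand_0, ..., cand_{top_k-1}] and the name decode_action are hypothetical; the idea is only that step() would call something like this to turn the slot into a node id before applying the shift. This covers the environment side only.

```python
from typing import Sequence, Tuple


def decode_action(action: Sequence[int], top_k: int,
                  num_shift_operations: int) -> Tuple[int, int]:
    """Return (node_id, shift_op) from an action laid out as
    [slot, shift_op, cand_0, ..., cand_{top_k-1}] (hypothetical layout)."""
    slot, shift_op = int(action[0]), int(action[1])
    candidates = [int(a) for a in action[2:]]
    if len(candidates) != top_k:
        raise ValueError(f"expected {top_k} candidate node ids, got {len(candidates)}")
    if not 0 <= slot < top_k or not 0 <= shift_op < num_shift_operations:
        raise ValueError("slot or shift_op out of range")
    return candidates[slot], shift_op


# Slot 2 of the candidates [5, 14, 11, 8, 7] with shift operation 1
print(decode_action([2, 1, 5, 14, 11, 8, 7], top_k=5, num_shift_operations=3))  # (11, 1)
```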

Thanks for your answer. You are correct, I want to pass additional information from my model to the environment. Modifying the environment wouldn’t be a problem.
You mentioned the possibility of including the additional information in the action. But then the action distribution class would also try to sample from that additional information (the selected nodes), and I don’t want it to be sampled. That’s why I don’t think this solution would work. Do you have any other ideas?
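One direction that addresses the sampling concern, sketched here under assumptions: let the model append the top-k node ids after the logits, and use a distribution that samples only the logit columns and copies the trailing columns into the action unchanged. In RLlib this would have to be a custom action distribution (a TorchDistributionWrapper subclass registered with ModelCatalog.register_custom_action_dist), with num_outputs and the action space widened to match; the plain-Python class below is hypothetical and only illustrates the column split, not RLlib's actual interface. Because the slot stays in the action, the log-probability is computed from the slot and shift operation alone. Note also that perm holds indices into the batched node tensor, so the ids may need shifting by each graph's node offset before being copied.

```python
import math
import random
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # shift by the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]


class PassThroughDistSketch:
    """Hypothetical distribution over one row of model output laid out as
    [top_k slot logits | num_shift logits | top_k candidate node ids].
    Only the two logit groups are sampled; the node ids are copied as-is."""

    def __init__(self, row: Sequence[float], top_k: int, num_shift: int):
        row = list(row)
        self.slot_logp = log_softmax(row[:top_k])
        self.shift_logp = log_softmax(row[top_k:top_k + num_shift])
        self.candidates = [int(v) for v in row[top_k + num_shift:]]
        if len(self.candidates) != top_k:
            raise ValueError("expected top_k candidate node ids after the logits")

    def sample(self, rng: random.Random) -> List[int]:
        slot = rng.choices(range(len(self.slot_logp)),
                           weights=[math.exp(v) for v in self.slot_logp])[0]
        shift = rng.choices(range(len(self.shift_logp)),
                            weights=[math.exp(v) for v in self.shift_logp])[0]
        return [slot, shift, *self.candidates]  # candidate ids pass through unsampled

    def logp(self, action: Sequence[int]) -> float:
        # Only the two sampled entries contribute; the copied ids do not affect
        # the log-probability (and therefore not the policy loss).
        return self.slot_logp[int(action[0])] + self.shift_logp[int(action[1])]


row = [0, 0, 50, 0, 0, 0, 50, 0, 5, 14, 11, 8, 7]
dist = PassThroughDistSketch(row, top_k=5, num_shift=3)
action = dist.sample(random.Random(0))
print(action, dist.logp(action))
```

The environment would then read the node id as action[2 + action[0]].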