Hi everyone,
I am trying to pass additional information from custom_model to my environment. Here is my CustomModel:
class GNNPPO(TorchModelV2, nn.Module):
    """Custom RLlib Torch model built from a SAGEConv stack and TopKPooling.

    The forward pass converts the observation into a graph batch, runs it
    through the convolution layers, pools it, and feeds the flattened result
    to two action heads and a critic head. The concatenated action logits are
    returned; the critic output is stored on ``self._value`` and the index
    tensor returned by the pooling layer on ``self.top_k_nodes``.
    """

    def __init__(self, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, num_outputs: int,
                 model_config: ModelConfigDict, name):
        """Build the GNN layers, the pooling layer and the output heads.

        ``model_config["custom_model_config"]["num_shift_operations"]`` sets
        the output width of the shift-operation head.
        """
        # Both base classes are initialised explicitly, TorchModelV2 first.
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        # Setup elided in the original snippet. Presumably defines in_channels,
        # hidden_channels, num_gnn_layers, self.top_k, self.dropout and
        # self.norm_processing_time, which are used below -- TODO confirm.
        ...

        # One input layer followed by (num_gnn_layers - 1) hidden layers.
        gnn_layers = [SAGEConv(in_channels, hidden_channels)]
        gnn_layers.extend(
            SAGEConv(hidden_channels, hidden_channels)
            for _ in range(num_gnn_layers - 1)
        )
        self.convs = torch.nn.ModuleList(gnn_layers)

        self.pool = TopKPooling(in_channels=hidden_channels, ratio=self.top_k)
        # fc1 consumes top_k node embeddings flattened into one row per graph.
        self.fc1 = torch.nn.Linear(self.top_k * hidden_channels, hidden_channels)

        # Head with one logit per pooled node.
        self.action_head_top_k_nodes = torch.nn.Linear(hidden_channels, self.top_k)
        # Head with one logit per shift operation.
        num_shift_ops = model_config["custom_model_config"]["num_shift_operations"]
        self.action_head_shift_op = torch.nn.Linear(hidden_channels, num_shift_ops)

        self._critic_head = nn.Sequential(nn.Linear(hidden_channels, 1))

    def forward(
        self,
        input_dict: Dict[str, TensorType],
        state: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        """Compute the concatenated action logits for a batch of observations.

        Reads ``input_dict["obs"].values``, converts it with
        ``schedule_to_graph`` and returns ``(logits, state)``; ``state`` is
        passed through unchanged and ``seq_lens`` is not used.

        Side effects: sets ``self.top_k_nodes`` to the fifth value returned by
        the pooling layer and ``self._value`` to the critic head output.
        """
        graphs = schedule_to_graph(
            input_dict["obs"].values,
            norm_processing_time=self.norm_processing_time,
        )
        h = graphs.x
        edges = graphs.edge_index
        batch_index = graphs.batch

        # Message passing: conv -> ReLU -> dropout for every layer.
        for layer in self.convs:
            h = nn.functional.relu(layer(h, edges))
            h = nn.functional.dropout(h, p=self.dropout, training=self.training)

        # Only the pooled features and the index tensor are used afterwards.
        # NOTE(review): kept_nodes presumably indexes into the batched node
        # tensor (all graphs concatenated), not into each graph -- confirm
        # against the TopKPooling documentation before using it per sample.
        h, _, _, _, kept_nodes, _ = self.pool(h, edges, None, batch_index, None)
        self.top_k_nodes = kept_nodes

        # Flatten the pooled node embeddings into one row of width
        # top_k * feature_dim; assumes every graph keeps exactly top_k nodes.
        h = h.view(-1, self.top_k * h.size(1))
        h = nn.functional.relu(self.fc1(h))
        h = nn.functional.dropout(h, p=self.dropout, training=self.training)

        node_logits = self.action_head_top_k_nodes(h)
        shift_logits = self.action_head_shift_op(h)
        action_logits = torch.cat([node_logits, shift_logits], dim=1)

        self._value = self._critic_head(h)
        return action_logits, state
I use a GNN. In this GNN there is a so-called TopKPooling layer, which passes the “most important” nodes of the graph on to the next layers. The layer self.action_head_top_k_nodes
should produce logits corresponding to the importance of the selected nodes. However, my environment needs to know which nodes these logits correspond to. E.g. the TopKPooling layer selected the nodes [5, 14, 11, 8, 7] and the action sampled from the softmax over the logits from self.action_head_top_k_nodes
was 2. Then my environment needs to know that node [5, 14, 11, 8, 7][2] = 11 was selected. How can I pass this information to the environment?
Please, I need your help!
Thanks in advance