How severely does this issue affect your experience of using Ray?
- High: It blocks me from completing my task.
Ray: 2.7.2
PyG: 2.3.0
I have a custom environment that provides a graph as its observation. The observation is fed to a custom policy model, which uses a PyG GCNConv layer to process the input graph and then passes the GCNConv output to a sequential model that generates the action logits and the value function output.
The policy model processes the input graphs in three steps:
- Reconstruct the graph's node and edge info from the flattened observation.
- Create a PyG Batch object from a list of Data objects, one per graph in the input.
- Pass the Batch object to the GCNConv layer, whose output is passed to the sequential network to generate the action logits.
Please see the sample policy model code below for reference:
import torch
import torch.nn as nn
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data, Batch
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.models.torch.misc import SlimFC, normc_initializer
class GNNModule(PPOTorchRLModule):
    def __init__(self, config):
        super().__init__(config)
        # NOTE: num_layers, norm and dropout are documented arguments of the
        # multi-layer torch_geometric.nn.GCN model, not of a single GCNConv layer.
        self.ggnn = GCNConv(in_channels=100, out_channels=100, num_layers=2, norm='BatchNorm', dropout=0.2)
        self.policy_head = nn.Sequential(
            nn.Linear(100, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        )
        self._value_branch = SlimFC(
            in_size=2,
            out_size=1,
            initializer=normc_initializer(0.01),
            activation_fn=None)

    def _forward_inference(self, batch):
        with torch.no_grad():
            return self._common_forward(batch)

    def _forward_exploration(self, batch):
        with torch.no_grad():
            return self._common_forward(batch)

    def _forward_train(self, batch):
        return self._common_forward(batch)

    def _common_forward(self, batch):
        max_node_number = 400
        state_size = 100
        max_edge_count = 1000
        obs = batch["obs"]
        batch_size = obs.shape[0]
        cur_idx = 0

        # Step 1: Reconstruct edge and node info from the flattened observation
        gnn_edge_index = (obs[:, cur_idx: cur_idx + max_edge_count*2]).reshape(batch_size, -1, 2)
        cur_idx += max_edge_count*2
        gnn_input_x = (obs[:, cur_idx : (cur_idx + max_node_number*state_size)]).reshape(batch_size, -1, state_size)

        # Step 2: Build one PyG Batch from the graphs in the input.
        # PyG expects edge_index with shape [2, num_edges], so each
        # [max_edge_count, 2] block of (source, target) pairs is transposed.
        data_list = [
            Data(x=gnn_input_x[i], edge_index=gnn_edge_index[i].long().t().contiguous())
            for i in range(batch_size)
        ]
        batch_data = Batch.from_data_list(data_list=data_list)

        # Step 3: Call the GCN layer on the batched data
        x = torch.tanh(self.ggnn(batch_data.x, batch_data.edge_index).reshape(batch_size, -1, 100))
        action_logits = self.policy_head(x)
        vf_value = self._value_branch(action_logits).squeeze(1)
        return {"action_dist_inputs": action_logits, "vf_preds": vf_value}
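The slicing in step 1 relies on the flattened observation placing the edge block before the node block. A minimal sketch of that unpacking as a standalone helper with a shape check is given below. The helper name `unflatten_graph_obs` is hypothetical, and the layout (edges first as (source, target) pairs, then node features) is an assumption read off the slicing in `_common_forward`, not something RLlib guarantees here:

```python
import torch


def unflatten_graph_obs(obs, max_edges, max_nodes, state_size):
    """Split a flat observation batch into edge pairs and node features.

    Assumed layout per row (hypothetical, mirrors the model code above):
      [2 * max_edges edge values | max_nodes * state_size node values]
    Returns edges of shape [B, max_edges, 2] and nodes of shape
    [B, max_nodes, state_size].
    """
    edge_len = 2 * max_edges
    node_len = max_nodes * state_size
    if obs.shape[1] != edge_len + node_len:
        raise ValueError(
            f"expected {edge_len + node_len} features per row, got {obs.shape[1]}"
        )
    # torch.split with a list of sizes cuts the feature dimension in order.
    edges_flat, nodes_flat = torch.split(obs, [edge_len, node_len], dim=1)
    edges = edges_flat.reshape(-1, max_edges, 2).long()
    nodes = nodes_flat.reshape(-1, max_nodes, state_size)
    return edges, nodes


# Round trip on a tiny example: 5 graphs, 3 edges, 4 nodes, 2 features each.
B, E, N, F = 5, 3, 4, 2
demo_edges = torch.randint(0, N, (B, E, 2))
demo_nodes = torch.randn(B, N, F)
demo_flat = torch.cat(
    [demo_edges.reshape(B, -1).float(), demo_nodes.reshape(B, -1)], dim=1
)
out_edges, out_nodes = unflatten_graph_obs(demo_flat, E, N, F)
```

The explicit length check makes a layout mismatch fail loudly instead of silently producing misaligned node features.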
I have the following queries:
- How can we encode a graph in an observation? Currently I pass it as a gym Dict space containing two Box spaces, one for the node details and one for the edge details. Because the observation gets flattened during preprocessing, I have to reconstruct the node and edge structure in step 1. Is there a better way to encode a graph as an observation?
- In step 2, I have to write a for loop to create a PyG Batch object from the input graphs. I believe the for loop serializes the work and therefore slows the model down. Is there some way to avoid step 2 in the forward function?
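For the second query, one loop-free option can be sketched as follows. It assumes every graph is padded to the same number of nodes and edges (as with `max_node_number` and `max_edge_count` above), so the per-graph concatenation and edge-index offsetting that `Batch.from_data_list` performs can be written as plain tensor operations. The function name `batch_fixed_size_graphs` is hypothetical and this is an illustration under those assumptions, not an RLlib or PyG API:

```python
import torch


def batch_fixed_size_graphs(x, edge_pairs):
    """Merge B equally sized graphs into one large disconnected graph.

    x:          [B, N, F] node features (N nodes per graph, padded).
    edge_pairs: [B, E, 2] (source, target) pairs with local ids in [0, N).
    Returns (big_x [B*N, F], big_edge_index [2, B*E], batch_vec [B*N]).
    """
    B, N, F = x.shape
    # Graph i owns the global node ids i*N ... i*N + N - 1.
    offsets = torch.arange(B, device=edge_pairs.device).view(B, 1, 1) * N
    global_pairs = edge_pairs.long() + offsets            # [B, E, 2]
    big_edge_index = global_pairs.reshape(-1, 2).t().contiguous()  # [2, B*E]
    big_x = x.reshape(B * N, F)
    # Maps every node to the index of the graph it came from.
    batch_vec = torch.arange(B, device=x.device).repeat_interleave(N)
    return big_x, big_edge_index, batch_vec


# Tiny example: 2 graphs with 3 nodes, 4 features and 2 edges each.
demo_x = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
demo_pairs = torch.tensor([[[0, 1], [1, 2]],
                           [[0, 2], [2, 1]]])
big_x, big_edge_index, batch_vec = batch_fixed_size_graphs(demo_x, demo_pairs)
```

Under these assumptions, `big_x` and `big_edge_index` could be passed to the GCNConv layer in place of `batch_data.x` and `batch_data.edge_index`. Padded edge slots are treated as real edges here, exactly as in the loop version, so they would still need masking or a harmless padding convention.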