How to define a graph type observation space and use torch_geometric?

Hi folks,

I am using RLlib to train an agent in a custom env whose observation_space must hold graph data, similar to the `Data` type (imported here as `gData`) from the torch_geometric library.

I implemented a dummy env and a custom model, but it does not work. Could someone have a look and let me know how I can change it to make it work? Here is runnable code:

# %%
import os
from collections import OrderedDict

import gymnasium as gym
import numpy as np
import torch
from torch import nn

import ray
from ray import air, tune
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.utils import check_env
from ray.rllib.utils.spaces.repeated import Repeated
from ray.tune.logger import pretty_print
from ray.tune.registry import get_trainable_cls

from torch_geometric.data import Batch as gBatch
from torch_geometric.data import Data as gData
from torch_geometric.loader import DataLoader as gDataLoader
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

from callbacks import MyCallbacks

class SimpleGnnEnv(gym.Env):
    """Toy episodic env whose observations are small random graphs.

    Each observation is a dict with:
      * ``'gnn_nodes'`` -- float32 node-feature matrix of shape
        ``(num_nodes, node_feature_dim)``;
      * ``'gnn_edges'`` -- a variable-length (``Repeated``) sequence of at most
        ``max_edges`` int64 ``[src, dst]`` node-index pairs.

    Actions are ignored; an episode terminates on its 4th step, which is the
    only step that yields reward 1.
    """

    def __init__(self, env_config=None):
        # Avoid a shared mutable default; same effective default as before.
        if env_config is None:
            env_config = {'env_name': 'SimpleGnnEnv'}
        self.env_name = env_config['env_name']
        self.n_actions = 3
        self.n_states = 5
        self.action_space = gym.spaces.Discrete(self.n_actions)
        self.num_nodes = 5
        self.num_edges = 3
        self.max_edges = 9
        self.node_feature_dim = 7
        self.timestep = 0
        # One edge = [src, dst], both valid node indices.
        edge = gym.spaces.Box(low=0, high=self.num_nodes - 1,
                              shape=(2,), dtype=np.int64)
        self.observation_space = gym.spaces.Dict({
            'gnn_nodes': gym.spaces.Box(low=-5, high=5,
                                        shape=(self.num_nodes, self.node_feature_dim),
                                        dtype=np.float32),
            'gnn_edges': Repeated(edge, self.max_edges),
        })

    def _random_obs(self):
        """Return a random observation that lies inside ``observation_space``."""
        return OrderedDict({
            # Box.contains rejects float64 data for a float32 space, so cast.
            'gnn_nodes': self.np_random.random(
                (self.num_nodes, self.node_feature_dim)).astype(np.float32),
            'gnn_edges': self.np_random.integers(
                0, self.num_nodes, size=(self.num_edges, 2), dtype=np.int64),
        })

    def reset(self, *, seed=None, options=None):
        """Start a new episode; returns ``(observation, info)``."""
        # Seeds self.np_random when a seed is given.
        super().reset(seed=seed)
        self.timestep = 0
        return self._random_obs(), {}

    def _update_obs(self, action):
        """Produce the next observation (independent of ``action``)."""
        return self._random_obs()

    def _execute_action(self, action):
        """Return ``(next_observation, reward, done)`` for the current step."""
        next_observation = self._update_obs(action)
        done = self.timestep > 3
        reward = 1 if done else 0
        return next_observation, reward, done

    def _get_info(self, done):
        """Build the info dict reported on the terminal step."""
        random_info_dict = {'random_info': done}
        info = {'agent_11': random_info_dict, 'timestep': self.timestep, 'done': done}
        return info

    def step(self, action):
        """Gymnasium step: ``(obs, reward, terminated, truncated, info)``."""
        self.timestep += 1
        self.observation, self.reward, self.done = self._execute_action(action)
        # Episodes only end by termination, never by a time limit.
        self.truncated = False
        self.info = self._get_info(self.done) if self.done else {}
        return self.observation, self.reward, self.done, self.truncated, self.info

    def seed(self, seed: int = None):
        """Legacy seeding hook: reseed ``self.np_random`` and return ``[seed]``."""
        self.np_random, seed = gym.utils.seeding.np_random(seed)
        return [seed]

class TinyGnnNet(TorchModelV2, nn.Module):
    """GCN actor-critic ModelV2 for ``SimpleGnnEnv`` graph observations.

    Node features pass through two ``GCNConv`` layers and a linear layer. They
    are then mean-pooled into one embedding per graph and fed through a shared
    trunk into an actor head (``num_outputs`` logits) and a critic head
    (scalar value).
    """

    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        # Must run before any submodule is assigned.
        nn.Module.__init__(self)
        self.obs_space = obs_space
        self.action_space = action_space
        self.num_outputs = num_outputs
        self.model_config = model_config
        self.name = name
        # Prefer the un-flattened (Dict/Repeated) space when one is attached.
        self.orig_space = getattr(obs_space, "original_space", obs_space)
        self.p = 0.0
        self.hidden_dim = 256
        self.actor_hidden_dim = 256
        self.critic_hidden_dim = 256
        self.node_feature_dim = 7
        self.feature_net_inp_layer = GCNConv(self.node_feature_dim, 2 * self.hidden_dim)
        # A ModuleList rather than nn.Sequential: GCNConv needs (x, edge_index),
        # but nn.Sequential can only forward a single input.
        self.feature_net_1 = nn.ModuleList([
            GCNConv(2 * self.hidden_dim, self.hidden_dim),
            nn.Linear(self.hidden_dim, self.hidden_dim),
        ])
        self.feature_net_out_layer = nn.Sequential(
            nn.Linear(self.hidden_dim, self.actor_hidden_dim),
            nn.ReLU(),
        )
        self._actor_head = nn.Sequential(
            nn.Linear(self.actor_hidden_dim, self.num_outputs))
        self._critic_head = nn.Sequential(
            nn.Linear(self.critic_hidden_dim, 1))
        # Filled by forward(); read by value_function().
        self._value = None

    def forward(self, input_dict, state, seq_lens):
        """Return ``(action_logits, state)`` and cache the value estimate."""
        gdata = self._convert_graph_data_dict_to_gdata_tensor(input_dict["obs"])
        x, edge_index = gdata.x, gdata.edge_index
        conv_1, lin_1 = self.feature_net_1
        x = torch.relu(self.feature_net_inp_layer(x, edge_index))
        x = torch.relu(conv_1(x, edge_index))
        x = torch.relu(lin_1(x))
        # Collapse node embeddings into one embedding per graph: (B, hidden_dim).
        x = global_mean_pool(x, gdata.batch)
        x = self.feature_net_out_layer(x)
        self._value = self._critic_head(x).reshape(-1)
        return self._actor_head(x), state

    def value_function(self):
        """Return the value estimates computed by the last ``forward`` call."""
        assert self._value is not None, "forward() must be called before value_function()"
        return self._value

    def _convert_graph_data_dict_to_gdata_tensor(self, real_obs):
        """Convert a (batched) graph observation into one PyG ``Batch``.

        ``real_obs['gnn_nodes']`` is node features, ``(B, N, F)`` or an
        unbatched ``(N, F)``. ``real_obs['gnn_edges']`` is either an object
        exposing ``.values``/``.lengths`` or a plain ``(B, E, 2)`` / ``(E, 2)``
        array of ``[src, dst]`` pairs. In the first case (assumed here to be
        RLlib's ``RepeatedValues`` — confirm), the values are zero-padded and
        only the first ``lengths[i]`` rows of sample ``i`` are real edges.

        Each sample becomes its own ``Data`` graph, with ``edge_index``
        transposed to PyG's ``(2, num_edges)`` layout. ``Batch.from_data_list``
        then merges the graphs into one disjoint graph and records the
        per-node ``batch`` vector used for pooling.
        """
        nodes = torch.as_tensor(real_obs['gnn_nodes'], dtype=torch.float)
        edges = real_obs['gnn_edges']
        if hasattr(edges, 'lengths'):
            edge_values = torch.as_tensor(edges.values)
            edge_counts = [int(n) for n in torch.as_tensor(edges.lengths).reshape(-1)]
        else:
            edge_values = torch.as_tensor(edges)
            edge_counts = None

        if nodes.dim() == 2:  # single unbatched graph
            nodes = nodes.unsqueeze(0)
        if edge_values.numel() == 0:  # no edges anywhere in the batch
            edge_values = edge_values.reshape(nodes.shape[0], 0, 2)
        elif edge_values.dim() == 2:  # single unbatched edge list
            edge_values = edge_values.unsqueeze(0)

        graphs = []
        for i, node_features in enumerate(nodes):
            n_edges = edge_values.shape[1] if edge_counts is None else edge_counts[i]
            # (n_edges, 2) rows of [src, dst] -> (2, n_edges) integer edge_index.
            edge_index = edge_values[i, :n_edges].long().t().contiguous()
            graphs.append(gData(x=node_features, edge_index=edge_index))
        return gBatch.from_data_list(graphs)
# %% Main
if __name__ == "__main__":
    # Either smoke-test the env with random actions ('random') or train the
    # agent with the custom GNN model ('trainer' = Algorithm, 'tunner' = Tune).
    env_name = 'SimpleGnnEnv'
    agent_name = 'PPO'  # alpha_zero
    learner_name = 'trainer'  # trainer tunner  random
    num_iters = 1
    num_rollout_workers = 1
    env_config = {'env_name': env_name}

    store_flag = False
    save_env_data_flag = False
    save_agent_flag = False
    load_agent_flag = False
    num_gpus = 0
    local_mode_flag = True
    current_dir = os.getcwd()
    storage = os.path.join(current_dir, 'storage')
    env_data_dir = os.path.join(storage, 'env_dict')
    save_chkpt_to_dir = os.path.join(storage, 'models')
    load_chkpt_from_dir = os.path.join(storage, agent_name)

    if learner_name == 'random':
        env = SimpleGnnEnv(env_config=env_config)
        s, _ = env.reset()
        a = env.action_space.sample()
        s, _, _, _, _ = env.step(a)
        print('Checking the env ...')
        while True:
            action = env.action_space.sample()
            obs, rew, done, truncated, info = env.step(action)
            if done or truncated:
                print(f'info: {info}')
                break  # without this the loop never ends
            print("not done yet")
    else:
        ray.init(local_mode=local_mode_flag)
        ModelCatalog.register_custom_model('TinyGnnNet', TinyGnnNet)
        algo_cls = get_trainable_cls(agent_name)
        param_space = (
            algo_cls.get_default_config()
            .environment(SimpleGnnEnv, env_config=env_config)
            .framework('torch')
            .rollouts(num_rollout_workers=num_rollout_workers)
            .resources(num_gpus=num_gpus)
            # TinyGnnNet is an old-stack ModelV2: switch off the RLModule /
            # Learner API, otherwise the custom model is silently not used.
            .rl_module(_enable_rl_module_api=False)
            .training(model={"custom_model": 'TinyGnnNet',
                             "vf_share_layers": True},
                      _enable_learner_api=False,
                      # train_batch_size=4,
                      # sgd_minibatch_size=2,
                      )
        )

        if save_env_data_flag:
            param_space.callbacks_class = MyCallbacks
            param_space.output = env_data_dir
            param_space.output_max_file_size = 5000000
            param_space.output_config = {
                "format": "json",  # json or parquet
                # Directory to write data files.
                "path": env_data_dir,
                # Break samples into multiple files, each containing about this many records.
                "max_num_samples_per_file": 100000,
            }

        if learner_name == 'trainer':
            # Offline output (if any) is already configured on param_space.
            algo = param_space.build()
            if load_agent_flag:
                algo.restore(load_chkpt_from_dir)
                print("In trainer: The model loaded!")
            for n in range(num_iters):
                print(f"---------- in trainer: episode: {n}")
                result = algo.train()
                print(pretty_print(result))
                if save_agent_flag:
                    checkpoint_dir = algo.save(save_chkpt_to_dir)
                    print(f"Checkpoint saved in: {checkpoint_dir}")
        elif learner_name == 'tunner':
            stop = {"training_iteration": num_iters}
            run_config = air.RunConfig(stop=stop)
            if load_agent_flag:
                # Tuner.restore is a classmethod that *returns* the restored
                # Tuner; calling it on an instance and ignoring the result
                # restores nothing.
                tuner = tune.Tuner.restore(load_chkpt_from_dir, trainable=agent_name)
            else:
                tuner = tune.Tuner(agent_name,
                                   param_space=param_space.to_dict(),
                                   run_config=run_config)
            results = tuner.fit()
        ray.shutdown()


Hi @deepgravity ,

Can you share a little bit about what is not working? The high-level idea of representing the graphs as a nested dict and constructing the graph structure in models’ code sounds reasonable to me.

Hi @kourosh, thank you for your reply.

Well, the problem is that the code runs without any errors. That is really strange, because I am fairly sure there are issues in the way I defined my observation_space, and more specifically in the way I feed the observation to torch_geometric's conv layers.

I suspect the problem is in this method: _convert_graph_data_dict_to_gdata_tensor

My observation dim is:

gnn_nodes = real_obs['gnn_nodes'].shape -> (num_nodes, node_features_dim) = (5, 7)

gnn_edges = real_obs['gnn_edges'].shape -> (adaptive_size, 2) = (variable number of edges, 2)

Also, Ray adds a batch_size dimension to them:

gnn_nodes = real_obs['gnn_nodes'].shape -> (batch_size, num_nodes, node_features_dim) = (32, 5, 7)

gnn_edges = real_obs['gnn_edges'].shape -> (batch_size, adaptive_size, 2) = (32, variable number of edges, 2)

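However, torch_geometric's `Data` class (gData) should not be able to read these tensors. It expects 2-D inputs (x of shape [num_nodes, num_features] and edge_index of shape [2, num_edges]), not 3-D tensors with a leading batch dimension.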

Also, torch_geometric batches graphs in its own way (it merges them into one large disconnected graph plus a per-node `batch` vector), which differs from RLlib's batching.

Despite all these issues, my code runs fine, which is really strange. Moreover, although I set local_mode to True, Ray ignores all my breakpoints when I run the code in debug mode, so I cannot debug it either!

Would you please run the code, check whether you can run it in debug mode, and look for the errors I mentioned above?

I think this code could ultimately be very useful for anyone who wants to use graph neural networks with RLlib.


Hi @deepgravity,

You might be able to get the breakpoints to work if you set the number of rollout workers to 0 and create_env_on_driver to True.

Hi @kourosh, thanks for your reply, but I still cannot run the code in debug mode!

Hi @deepgravity,

In our latest RLlib releases, we have turned the RLModule/Learner API on by default for the PPO algorithm, which soft-deprecates the ModelV2 customization stack. That's why execution never stopped at your breakpoints inside the custom model. Sorry for the confusion. We need to make a prominent announcement on the doc pages and add a code warning 🙂
I filed an issue to track this: [RLlib] If RLModule is enabled by default user should get warnings if they are using old stack related components. · Issue #37085 · ray-project/ray · GitHub

To reach your breakpoint, you need to turn off the RLModule and Learner APIs via the configuration:

config = config.rl_module(_enable_rl_module_api=False).training(_enable_learner_api=False)

Hi Kourosh, thank you for your reply. I can now run the code in debug mode, but I still do not know how to use graph data and a graph observation space. Could you please have a look at the code I sent before? Thanks!

Have you checked out the externally contributed GNN use case in TF? It is mentioned as part of rllib_contrib.