Error: two structures are not the same. GNN Model

How severe does this issue affect your experience of using Ray?

  • High: It blocks me to complete my task.

Hello, I am trying to train a GNN through RL in order to take actions in a graph environment.
My env returns as an observation a Graph that grows after every step.

print( type(obs), obs["x"].shape, obs["edge_index"].shape, obs["edge_attr"].shape )
<class 'dict'> torch.Size([7, 2]) torch.Size([6, 2]) torch.Size([6, 6])

I want to process this through a custom model in which I define my GNN architecture with PyTorch Geometric, and train it with a standard RL algorithm that supports continuous action spaces, such as PPO.

Here is my full code dealing with RLlib:

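# Assumed imports: the original post did not show them.
import argparse

import gym  # on newer Ray versions this may need to be "import gymnasium as gym"
import numpy as np
import ray
from gym.spaces import Box  # same caveat as for gym above
from ray.rllib.algorithms import ppo
from ray.rllib.models import ModelCatalog
from ray.rllib.models.catalog import MODEL_DEFAULTS
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.repeated import Repeated
from ray.tune.registry import register_env
from torch.nn import ReLU

# GraphNet (custom model) and GNN_Tersq (environment) are defined elsewhere
# in the project and are not shown here.

torch, nn = try_import_torch()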


parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=50, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=0.1, help="Reward at which we stop training."
)
parser.add_argument(
    "--no-tune",
    action="store_true",
    help="Run without Tune using a manual train loop instead. In this case, "
    "use PPO without grid search and no TensorBoard.",
)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
)
parser.add_argument(
    "--func_id",
    type=int,
    default=1,
    choices=range(1, 16),
    help="Function from Cec2013 Benchmarks to optimize. From DMolina's Github",
)





if __name__ == "__main__":
    args = parser.parse_args()
    
    #print(f"Running with following CLI options: {args}")

    
    
    env = GNN_Tersq
    obs_space = gym.spaces.Dict({
                'nodes': Repeated(Box(low=-np.inf, high=np.inf, shape= (3,)), max_len= 5000),
                'edges': Repeated(Box(low=-np.inf, high=np.inf, shape=(2,)), max_len=10_000),
                'edge_attr': Repeated(Box(low=-np.inf, high=np.inf, shape=(1,)), max_len=10_000)})
    action_space = gym.spaces.Box(low=-1, high = 1, shape = (5,))   
 
 
    model = GraphNet(
    in_dims=[3,2,1],
    out_dims=[16,16,1],
    obs_space= obs_space,
    action_space= action_space,
    model_config= MODEL_DEFAULTS,
    name="Tersq-GN",
    num_outputs = 5,
    independent=False,
    e2v_agg="sum",
    n_hidden=1,
    hidden_size=64,
    activation=ReLU,
    layer_norm=True,
    )

    # Name must match the "custom_model" entry in the config below.
    ModelCatalog.register_custom_model("GraphNet", model)

    #function to optimize
    
    func_id = vars(args)["func_id"]
  

    env_config={"func": func_id,  
                "generation_size":200, 
                "log_name":"tersq",
                "observed_generations":200
                }

    def env_creator(env_config):
        return GNN_Tersq(env_config)
    
    register_env('GNN_Tersq',env_creator)

    
    ray.init()

    config = {
  
        "env": "GNN_Tersq" ,
        "framework":"torch",
        "num_rollout_workers":1,
        "action_space": action_space,
        "model":{
                "custom_model": "GraphNet",
                "vf_share_layers": True,
            },
        "env_config": env_config
    }
        
    algo = ppo.PPO(env="GNN_Tersq",config=config)
    for _ in range(5):
        algo.train()
                    
    ray.shutdown()

I get the error:

ERROR actor_manager.py:487 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=27950, ip=172.17.0.2, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7fcd77173c90>)
ValueError: The two structures don't have the same nested structure.

First structure: type=NoneType str=None

Second structure: type=OrderedDict str=OrderedDict([('edge_attr', [array([ 0.05828598,  1.6800339 ,  0.2460808 ,  0.909437  ,  1.1108482 ,
       -1.2176332 ], dtype=float32), array([-0.8209033 ,  1.5708834 , -1.7356552 , -0.7804611 , -0.5665525 ,
       -0.76005054], dtype=float32), array([-0.10791866,  0.25782433,  0.7622504 , -1.6059427 ,  0.11802711,
       -1.5304987 ], dtype=float32), array([ 0.11877615, -0.00338906, -0.56129813, -1.1324824 ,  1.5028561 ,
       -0.05415997], dtype=float32), array([ 0.84529656,  0.05346558,  1.0095785 , -0.3181351 ,  1.2736853 ,
       -0.1735463 ], dtype=float32), array([ 2.1919084 ,  1.6965641 ,  0.2610673 , -2.080992  ,  0.2640835 ,
       -0.95981544], dtype=float32), array([-0.8087102,  1.678714 ,  1.240176 ,  1.9777763, -1.6990196,
        0.9117386], dtype=float32), array([-0.07565449, -0.32657084, -0.82357234,  0.26498142,  1.0119815 ,
        0.9550965 ], dtype=float32), array([ 1.0438224 , -0.13161731,  0.5500896 ,  1.947861  , -0.9712911 ,
       -0.25422993], dtype=float32), array([ 1.05997   , -0.9333079 , -0.95672804, -3.385738  ,  0.32287568,
       -0.5167504 ], dtype=float32), array([-0.5476242 , -0.24380413,  0.93366593,  0.10997711, -0.655031  ,
       -1.8571781 ], dtype=float32), array([-2.0268688 ,  2.25356   ,  0.12142886, -0.7362118 ,  0.44789505,
       -0.29472384], dtype=float32), array([-0.9474914 ,  0.22024356, -0.46483916, -0.7197042 , -1.7550093 ,
        0.8779702 ], dtype=float32), array([ 0.75778973,  0.1403252 ,  0.601641  , -0.83263403, -0.40459803,
        0.39522386], dtype=float32), array([-0.16012946, -0.81058335,  1.497457  ,  1.525054  ,  1.2396181 ,
        1.9565277 ], dtype=float32), array([ 0.01291324,  1.0813155 , -0.97064584,  0.38611206,  0.9009284 ,
       -0.21593454], dtype=float32), array([ 0.552598  , -1.1890013 , -0.78714156,  1.67644   , -1.1925207 ,
       -0.01577877], dtype=float32), array([-1.8247021 , -0.03998495, -0.9765081 ,  1.6951705 , -0.5382523 ,
        0.1107698 ], dtype=float32), array([ 1.8838704 ,  0.73540753,  1.7623731 ,  1.4731652 , -0.80677354,
        1.3925757 ], dtype=float32), array([-0.16919075,  0.37236318, -0.79319936,  0.7247965 , -0.43676952,
        0.23095421], dtype=float32), array([ 1.189384  , -0.33604422, -0.23432913,  0.91115797, -0.69075096,
        0.4003402 ], dtype=float32), array([-0.28841266, -0.60751665,  0.5960525 ,  2.182002  ,  0.34697813,
        0.10697586], dtype=float32), array([-0.59478056,  0.38831505,  0.31998768, -0.40115565, -0.4756637 ,
       -0.681964  ], dtype=float32), array([-0.32288793,  0.69854254, -0.92184806,  0.35502335,  0.4157108 ,
       -1.0448637 ], dtype=float32), array([-1.1380184 , -1.2641369 , -0.874923  ,  0.32890946,  1.5758855 ,
       -0.15539095], dtype=float32), array([ 0.8846484 , -1.366705  ,  0.46513095,  0.25799438,  0.8809303 ,
        1.3323174 ], dtype=float32)]), ('edges', [array([0, 0], dtype=int32), array([-2, -2], dtype=int32), array([-3,  0], dtype=int32), array([ 1, -2], dtype=int32), array([-1, -2], dtype=int32), array([-1,  0], dtype=int32), array([ 1, -2], dtype=int32), array([-1, -1], dtype=int32), array([-2,  0], dtype=int32), array([-2, -1], dtype=int32), array([ 0, -1], dtype=int32), array([-1, -2], dtype=int32), array([0, 1], dtype=int32), array([ 1, -1], dtype=int32), array([-2,  0], dtype=int32), array([ 0, -2], dtype=int32), array([0, 1], dtype=int32), array([ 1, -1], dtype=int32), array([ 1, -1], dtype=int32), array([-1, -1], dtype=int32), array([-1, -1], dtype=int32), array([ 0, -3], dtype=int32), array([-1, -1], dtype=int32), array([ 1, -2], dtype=int32), array([ 0, -1], dtype=int32), array([-1,  0], dtype=int32), array([ 0, -1], dtype=int32), array([1, 1], dtype=int32), array([-2, -1], dtype=int32), array([0, 0], dtype=int32), array([-1,  0], dtype=int32), array([-1,  1], dtype=int32), array([0, 0], dtype=int32), array([-2,  0], dtype=int32), array([-2, -1], dtype=int32), array([2, 0], dtype=int32), array([ 1, -1], dtype=int32), array([0, 0], dtype=int32), array([-1,  1], dtype=int32), array([-1, -2], dtype=int32), array([0, 0], dtype=int32), array([-1,  0], dtype=int32), array([-2,  0], dtype=int32), array([1, 0], dtype=int32), array([ 0, -1], dtype=int32), array([-1, -1], dtype=int32), array([ 0, -1], dtype=int32), array([-2, -1], dtype=int32), array([0, 0], dtype=int32), array([ 0, -1], dtype=int32), array([-1, -2], dtype=int32), array([0, 0], dtype=int32), array([0, 1], dtype=int32), array([ 1, -1], dtype=int32), array([ 0, -1], dtype=int32), array([ 2, -1], dtype=int32), array([3, 1], dtype=int32), array([-2, -2], dtype=int32), array([-1, -1], dtype=int32), array([-1, -1], dtype=int32), array([-3, -2], dtype=int32), array([-1,  1], dtype=int32), array([-1, -2], dtype=int32), array([ 0, -2], dtype=int32), array([-1,  0], dtype=int32), array([-3,  1], 
dtype=int32), array([-1, -1], dtype=int32)]), ('global_attr', array([1.4968876], dtype=float32)), ('nodes', [array([ 0.01219574, -0.31490242, -0.02738303], dtype=float32), array([-0.8618865, -1.624338 ,  1.5764035], dtype=float32), array([ 0.34143075,  0.00225972, -0.68763596], dtype=float32), array([ 0.02351123,  0.24912626, -0.69703215], dtype=float32), array([ 0.17028035,  0.9453144 , -0.7214167 ], dtype=float32), array([-0.013497 ,  1.4370594,  1.878117 ], dtype=float32), array([1.3843715, 0.5024601, 1.3597839], dtype=float32), array([-0.97689015, -0.03611429, -1.4003345 ], dtype=float32), array([ 1.4343497 ,  0.824217  , -0.80190295], dtype=float32), array([0.44231763, 0.90646625, 0.15509629], dtype=float32), array([-0.53610784,  0.18710728, -0.29608127], dtype=float32), array([-0.90295374,  1.4328693 , -0.33721954], dtype=float32)])])

I have searched for this error message on GitHub but have not been able to find it.
I have googled my problem and I am certain it is not a problem of my Torch tensors having the wrong dtype…
Any clue where these structures come from and what they mean?
Any help is greatly appreciated.

Hi @Maldades,

What you are looking at is a (re-raised) error. I’m pretty sure I know it from tensorflow, but maybe torch raises something super similar.
The error is raised in RolloutWorker, which is a Ray actor. To get more information, you can run your code with ray.init(local_mode=True); the error stack may then tell you more.
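To make the message itself a bit more concrete: its wording looks like the structure-mismatch errors raised by nest-style utilities, which compare the container layout (dict keys, list nesting) of two values and ignore the leaf data. The sketch below is not RLlib's code. `structure_of` and `same_structure` are hypothetical helpers written only for illustration, and the key names are copied from the question. Note that the sample printed in the error contains a `global_attr` key that the posted observation space does not declare, and that the first structure is a bare `None`.

```python
def structure_of(value):
    """Describe the container layout of a value, ignoring leaf contents."""
    if isinstance(value, dict):
        return {key: structure_of(value[key]) for key in sorted(value)}
    if isinstance(value, (list, tuple)):
        return [structure_of(item) for item in value]
    return "leaf"  # None, numbers, arrays, ... all count as leaves here


def same_structure(a, b):
    """True if both values have identical nested container layouts."""
    return structure_of(a) == structure_of(b)


# Keys declared by the observation space posted in the question.
declared = {"nodes": 0.0, "edges": 0.0, "edge_attr": 0.0}
# Keys of the sample printed in the error message.
sampled = {"nodes": 0.0, "edges": 0.0, "edge_attr": 0.0, "global_attr": 0.0}

print(same_structure(declared, sampled))         # False: 'global_attr' is extra
print(same_structure(None, sampled))             # False: a bare None vs. a dict
print(same_structure(declared, dict(declared)))  # True
```

Running a check like this on what the environment's reset() returns against a sample from the declared space is a cheap way to narrow down which of the two structures is the unexpected one.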
More high level thoughts on this:

  • RLlib creates models and their ViewRequirements on Algorithm instantiation. This means that if your neural network’s shapes change over time, assumptions we make about these don’t hold anymore.
  • Therefore, if your observation space changes over time, RLlib is very unlikely to work.
  • Do you get the same error if your graph does not grow?
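If the growing graph does turn out to be the cause, one possible workaround (an assumption on my side, not something confirmed in this thread) is to keep every observation the same shape by padding the node and edge lists up to a fixed cap and returning a mask that marks the real rows. A minimal stdlib-only sketch, where `pad_rows` and the cap of 5 are hypothetical names and values chosen for illustration:

```python
def pad_rows(rows, max_len, width, fill=0.0):
    """Pad a list of fixed-width rows to exactly max_len rows.

    Returns (padded_rows, mask), where mask[i] is 1 for a real row
    and 0 for a padding row.
    """
    if len(rows) > max_len:
        raise ValueError(f"{len(rows)} rows exceed max_len={max_len}")
    for row in rows:
        if len(row) != width:
            raise ValueError(f"expected rows of width {width}, got {len(row)}")
    n_pad = max_len - len(rows)
    padded = [list(row) for row in rows] + [[fill] * width for _ in range(n_pad)]
    mask = [1] * len(rows) + [0] * n_pad
    return padded, mask


# A graph with 3 nodes of 2 features each, padded to a hypothetical cap of 5.
nodes = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
padded_nodes, node_mask = pad_rows(nodes, max_len=5, width=2)
print(len(padded_nodes), node_mask)  # 5 [1, 1, 1, 0, 0]
```

The model would then need to use the mask to ignore padded rows, for example when building the edge index for PyTorch Geometric.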

I’ll look into making changes that make these errors more informative.