How severe does this issue affect your experience of using Ray?
- High: It blocks me to complete my task.
Hello, I am trying to train a GNN with RL so that it takes actions in a graph environment.
My env returns a graph as its observation, and this graph grows after every step. Printing one observation gives:
print( type(obs), obs["x"].shape, obs["edge_index"].shape, obs["edge_attr"].shape )
<class 'dict'> torch.Size([7, 2]) torch.Size([6, 2]) torch.Size([6, 6])
I want to process this observation with a custom model in which I define my GNN architecture with PyTorch Geometric, and train it with a regular RL algorithm that supports a continuous action space, such as PPO.
Here is my full code that deals with RLlib:
import argparse

import gymnasium as gym  # or `import gym`, depending on the installed Ray version
import numpy as np
import ray
from gymnasium.spaces import Box
from ray.rllib.algorithms import ppo
from ray.rllib.models.catalog import MODEL_DEFAULTS, ModelCatalog
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.spaces.repeated import Repeated
from ray.tune.registry import register_env

# GNN_Tersq (the env) and GraphNet (the custom model) are defined elsewhere in my project.

torch, nn = try_import_torch()
ReLU = nn.ReLU  # activation class passed to GraphNet below

parser = argparse.ArgumentParser()
parser.add_argument(
    "--run", type=str, default="PPO", help="The RLlib-registered algorithm to use."
)
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "torch"],
    default="torch",
    help="The DL framework specifier.",
)
parser.add_argument(
    "--as-test",
    action="store_true",
    help="Whether this script should be run as a test: --stop-reward must "
    "be achieved within --stop-timesteps AND --stop-iters.",
)
parser.add_argument(
    "--stop-iters", type=int, default=50, help="Number of iterations to train."
)
parser.add_argument(
    "--stop-timesteps", type=int, default=100000, help="Number of timesteps to train."
)
parser.add_argument(
    "--stop-reward", type=float, default=0.1, help="Reward at which we stop training."
)
parser.add_argument(
    "--no-tune",
    action="store_true",
    help="Run without Tune using a manual train loop instead. In this case, "
    "use PPO without grid search and no TensorBoard.",
)
parser.add_argument(
    "--local-mode",
    action="store_true",
    help="Init Ray in local mode for easier debugging.",
)
parser.add_argument(
    "--func_id",
    type=int,
    default=1,
    choices=range(1, 16),
    help="Function from Cec2013 Benchmarks to optimize. From DMolina's Github",
)
if __name__ == "__main__":
    args = parser.parse_args()
    # print(f"Running with following CLI options: {args}")

    # Observation layout the model is written for (RLlib itself reads the
    # observation space from the env unless one is set in the config).
    obs_space = gym.spaces.Dict({
        "nodes": Repeated(Box(low=-np.inf, high=np.inf, shape=(3,)), max_len=5000),
        "edges": Repeated(Box(low=-np.inf, high=np.inf, shape=(2,)), max_len=10_000),
        "edge_attr": Repeated(Box(low=-np.inf, high=np.inf, shape=(1,)), max_len=10_000),
    })
    action_space = gym.spaces.Box(low=-1, high=1, shape=(5,))

    # Custom constructor arguments for GraphNet.
    graphnet_kwargs = {
        "in_dims": [3, 2, 1],
        "out_dims": [16, 16, 1],
        "independent": False,
        "e2v_agg": "sum",
        "n_hidden": 1,
        "hidden_size": 64,
        "activation": ReLU,
        "layer_norm": True,
    }
    # register_custom_model expects the model class, not an instance. RLlib
    # instantiates it and forwards "custom_model_config" as keyword arguments.
    # The registered name must match "custom_model" in the config below.
    ModelCatalog.register_custom_model("GraphNet", GraphNet)

    # Function to optimize
    func_id = args.func_id
    env_config = {
        "func": func_id,
        "generation_size": 200,
        "log_name": "tersq",
        "observed_generations": 200,
    }

    def env_creator(env_config):
        return GNN_Tersq(env_config)

    register_env("GNN_Tersq", env_creator)

    ray.init()
    config = {
        "env": "GNN_Tersq",
        "framework": "torch",
        "num_rollout_workers": 1,
        "action_space": action_space,
        "model": {
            "custom_model": "GraphNet",
            "custom_model_config": graphnet_kwargs,
            "vf_share_layers": True,
        },
        "env_config": env_config,
    }
    algo = ppo.PPO(env="GNN_Tersq", config=config)
    for _ in range(5):
        algo.train()
    ray.shutdown()
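Note that the observation printed at the top uses the keys x, edge_index and edge_attr, whereas obs_space in the script declares nodes, edges and edge_attr. Whether this is related to the error is unclear to me. A sanity check along the following lines could rule it out. This is only a minimal standard-library sketch: missing_and_extra_keys is a hypothetical helper, and the empty lists stand in for the real tensors.

```python
def missing_and_extra_keys(declared_keys, observation):
    """Compare the keys a space declares with the keys an observation has."""
    declared = set(declared_keys)
    returned = set(observation)
    # Keys the space expects but the env did not return, and vice versa.
    return sorted(declared - returned), sorted(returned - declared)


# Keys declared in obs_space in the script above.
declared_keys = ["nodes", "edges", "edge_attr"]
# Keys of the printed observation; the values are placeholders.
observation = {"x": [], "edge_index": [], "edge_attr": []}

missing, extra = missing_and_extra_keys(declared_keys, observation)
print("missing from observation:", missing)
print("not declared in space:", extra)
```

If both lists come back empty for a real observation, the key names at least agree and the mismatch has to be somewhere else.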
I get the error:
ERROR actor_manager.py:487 -- Ray error, taking actor 1 out of service. The actor died because of an error raised in its creation task, ray::RolloutWorker.__init__() (pid=27950, ip=172.17.0.2, repr=<ray.rllib.evaluation.rollout_worker.RolloutWorker object at 0x7fcd77173c90>)
ValueError: The two structures don't have the same nested structure.
First structure: type=NoneType str=None
Second structure: type=OrderedDict str=OrderedDict([('edge_attr', [array([ 0.05828598, 1.6800339 , 0.2460808 , 0.909437 , 1.1108482 ,
-1.2176332 ], dtype=float32), array([-0.8209033 , 1.5708834 , -1.7356552 , -0.7804611 , -0.5665525 ,
-0.76005054], dtype=float32), array([-0.10791866, 0.25782433, 0.7622504 , -1.6059427 , 0.11802711,
-1.5304987 ], dtype=float32), array([ 0.11877615, -0.00338906, -0.56129813, -1.1324824 , 1.5028561 ,
-0.05415997], dtype=float32), array([ 0.84529656, 0.05346558, 1.0095785 , -0.3181351 , 1.2736853 ,
-0.1735463 ], dtype=float32), array([ 2.1919084 , 1.6965641 , 0.2610673 , -2.080992 , 0.2640835 ,
-0.95981544], dtype=float32), array([-0.8087102, 1.678714 , 1.240176 , 1.9777763, -1.6990196,
0.9117386], dtype=float32), array([-0.07565449, -0.32657084, -0.82357234, 0.26498142, 1.0119815 ,
0.9550965 ], dtype=float32), array([ 1.0438224 , -0.13161731, 0.5500896 , 1.947861 , -0.9712911 ,
-0.25422993], dtype=float32), array([ 1.05997 , -0.9333079 , -0.95672804, -3.385738 , 0.32287568,
-0.5167504 ], dtype=float32), array([-0.5476242 , -0.24380413, 0.93366593, 0.10997711, -0.655031 ,
-1.8571781 ], dtype=float32), array([-2.0268688 , 2.25356 , 0.12142886, -0.7362118 , 0.44789505,
-0.29472384], dtype=float32), array([-0.9474914 , 0.22024356, -0.46483916, -0.7197042 , -1.7550093 ,
0.8779702 ], dtype=float32), array([ 0.75778973, 0.1403252 , 0.601641 , -0.83263403, -0.40459803,
0.39522386], dtype=float32), array([-0.16012946, -0.81058335, 1.497457 , 1.525054 , 1.2396181 ,
1.9565277 ], dtype=float32), array([ 0.01291324, 1.0813155 , -0.97064584, 0.38611206, 0.9009284 ,
-0.21593454], dtype=float32), array([ 0.552598 , -1.1890013 , -0.78714156, 1.67644 , -1.1925207 ,
-0.01577877], dtype=float32), array([-1.8247021 , -0.03998495, -0.9765081 , 1.6951705 , -0.5382523 ,
0.1107698 ], dtype=float32), array([ 1.8838704 , 0.73540753, 1.7623731 , 1.4731652 , -0.80677354,
1.3925757 ], dtype=float32), array([-0.16919075, 0.37236318, -0.79319936, 0.7247965 , -0.43676952,
0.23095421], dtype=float32), array([ 1.189384 , -0.33604422, -0.23432913, 0.91115797, -0.69075096,
0.4003402 ], dtype=float32), array([-0.28841266, -0.60751665, 0.5960525 , 2.182002 , 0.34697813,
0.10697586], dtype=float32), array([-0.59478056, 0.38831505, 0.31998768, -0.40115565, -0.4756637 ,
-0.681964 ], dtype=float32), array([-0.32288793, 0.69854254, -0.92184806, 0.35502335, 0.4157108 ,
-1.0448637 ], dtype=float32), array([-1.1380184 , -1.2641369 , -0.874923 , 0.32890946, 1.5758855 ,
-0.15539095], dtype=float32), array([ 0.8846484 , -1.366705 , 0.46513095, 0.25799438, 0.8809303 ,
1.3323174 ], dtype=float32)]), ('edges', [array([0, 0], dtype=int32), array([-2, -2], dtype=int32), array([-3, 0], dtype=int32), array([ 1, -2], dtype=int32), array([-1, -2], dtype=int32), array([-1, 0], dtype=int32), array([ 1, -2], dtype=int32), array([-1, -1], dtype=int32), array([-2, 0], dtype=int32), array([-2, -1], dtype=int32), array([ 0, -1], dtype=int32), array([-1, -2], dtype=int32), array([0, 1], dtype=int32), array([ 1, -1], dtype=int32), array([-2, 0], dtype=int32), array([ 0, -2], dtype=int32), array([0, 1], dtype=int32), array([ 1, -1], dtype=int32), array([ 1, -1], dtype=int32), array([-1, -1], dtype=int32), array([-1, -1], dtype=int32), array([ 0, -3], dtype=int32), array([-1, -1], dtype=int32), array([ 1, -2], dtype=int32), array([ 0, -1], dtype=int32), array([-1, 0], dtype=int32), array([ 0, -1], dtype=int32), array([1, 1], dtype=int32), array([-2, -1], dtype=int32), array([0, 0], dtype=int32), array([-1, 0], dtype=int32), array([-1, 1], dtype=int32), array([0, 0], dtype=int32), array([-2, 0], dtype=int32), array([-2, -1], dtype=int32), array([2, 0], dtype=int32), array([ 1, -1], dtype=int32), array([0, 0], dtype=int32), array([-1, 1], dtype=int32), array([-1, -2], dtype=int32), array([0, 0], dtype=int32), array([-1, 0], dtype=int32), array([-2, 0], dtype=int32), array([1, 0], dtype=int32), array([ 0, -1], dtype=int32), array([-1, -1], dtype=int32), array([ 0, -1], dtype=int32), array([-2, -1], dtype=int32), array([0, 0], dtype=int32), array([ 0, -1], dtype=int32), array([-1, -2], dtype=int32), array([0, 0], dtype=int32), array([0, 1], dtype=int32), array([ 1, -1], dtype=int32), array([ 0, -1], dtype=int32), array([ 2, -1], dtype=int32), array([3, 1], dtype=int32), array([-2, -2], dtype=int32), array([-1, -1], dtype=int32), array([-1, -1], dtype=int32), array([-3, -2], dtype=int32), array([-1, 1], dtype=int32), array([-1, -2], dtype=int32), array([ 0, -2], dtype=int32), array([-1, 0], dtype=int32), array([-3, 1], dtype=int32), array([-1, -1], 
dtype=int32)]), ('global_attr', array([1.4968876], dtype=float32)), ('nodes', [array([ 0.01219574, -0.31490242, -0.02738303], dtype=float32), array([-0.8618865, -1.624338 , 1.5764035], dtype=float32), array([ 0.34143075, 0.00225972, -0.68763596], dtype=float32), array([ 0.02351123, 0.24912626, -0.69703215], dtype=float32), array([ 0.17028035, 0.9453144 , -0.7214167 ], dtype=float32), array([-0.013497 , 1.4370594, 1.878117 ], dtype=float32), array([1.3843715, 0.5024601, 1.3597839], dtype=float32), array([-0.97689015, -0.03611429, -1.4003345 ], dtype=float32), array([ 1.4343497 , 0.824217 , -0.80190295], dtype=float32), array([0.44231763, 0.90646625, 0.15509629], dtype=float32), array([-0.53610784, 0.18710728, -0.29608127], dtype=float32), array([-0.90295374, 1.4328693 , -0.33721954], dtype=float32)])])
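My reading of the message is that two nested containers are being compared by layout: one side is a bare None and the other is a dict of lists. The sketch below illustrates that kind of check using only the standard library. The names structure_of and same_structure are hypothetical helpers for illustration, not RLlib's actual implementation, and the small dict only imitates the shape of the second structure in the traceback.

```python
def structure_of(value):
    """Describe the container layout of a value, ignoring the leaf values."""
    if isinstance(value, dict):
        return {key: structure_of(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [structure_of(item) for item in value]
    return "leaf"  # None, numbers, arrays, ... all count as leaves here


def same_structure(first, second):
    """Return True if both values have the same nesting of dicts and lists."""
    return structure_of(first) == structure_of(second)


first = None  # like "First structure: type=NoneType str=None"
second = {"edge_attr": [1.0, 2.0], "edges": [0, 1], "nodes": [0.5]}

print(same_structure(first, second))
print(same_structure({"a": 1, "b": [2, 3]}, {"b": [0, 0], "a": None}))
```

In this sketch the first comparison fails because a single leaf is compared with a dict, while the second succeeds because only the nesting matters, not the values.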
I have searched for this error message on GitHub but have not been able to find it.
I have also googled my problem, and I am certain that it is not caused by my Torch tensors having the wrong dtype.
Any clue where these structures come from and what they mean?
Any help is greatly appreciated.