I found this post, where someone creates the graph in the model init method:
I pass the graph in through the observation instead:
In the observation space you can describe the edge list with a Repeated space:
edges = Repeated(Box(low=0, high=N, shape=(2,), dtype=np.int64), max_len=max_edges)
obs_space = Dict({'edges': edges})
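To make the "dimensions add up" point concrete, here is a rough, numpy-only sketch of the fixed-size encoding a Repeated space implies: a variable number of edges stored in a buffer sized for max_len items, plus the actual length. The exact internal layout RLlib uses (length in the first slot, then the flattened items, then zero padding) is my assumption for illustration, and flatten_edges / unflatten_edges are hypothetical helpers, not RLlib functions.

```python
import numpy as np

MAX_EDGES = 4  # hypothetical stand-in for max_edges


def flatten_edges(edges, max_len=MAX_EDGES):
    """Hypothetical helper: encode a variable-length edge list as one
    fixed-size float32 vector [length, u0, v0, u1, v1, ..., zero padding]."""
    edges = np.asarray(edges, dtype=np.int64).reshape(-1, 2)
    if len(edges) > max_len:
        raise ValueError(f"got {len(edges)} edges, max_len is {max_len}")
    flat = np.zeros(1 + 2 * max_len, dtype=np.float32)
    flat[0] = len(edges)                     # slot 0 holds the list length
    flat[1:1 + edges.size] = edges.ravel()   # edge pairs follow, the rest stays 0
    return flat


def unflatten_edges(flat, max_len=MAX_EDGES):
    """Inverse of flatten_edges: returns (padded values [max_len, 2], length)."""
    length = int(flat[0])
    values = flat[1:].reshape(max_len, 2)
    return values, length


flat = flatten_edges([[0, 1], [1, 2], [2, 0]])
print(flat)  # 3 edges, padded with zeros up to MAX_EDGES pairs
values, length = unflatten_edges(flat)
print(values[:length].astype(np.int64))
```

Note that the buffer is float32 even though the edges are integers, which is one way to see why the model receives floating point values that may need casting back.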
In the model, use the unflattened observation:
edges = input_dict["obs"]['edges']
which is of type RepeatedValues (ray.rllib.models.repeated_values).
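Make sure your dimensions match, and cast the tensor where needed: by default, observations reach the model as floating point numbers, so integer edge indices have to be cast back (for example with .long() in PyTorch) before you use them as indices.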
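As far as I know, a RepeatedValues object exposes the padded data as edges.values (shape [B, max_len, 2] here) and the per-sample counts as edges.lengths (shape [B]). The sketch below uses plain numpy arrays as stand-ins for those tensors to show the two usual steps: cast the float values back to integer indices, and ignore the padding rows. valid_edges and padding_mask are hypothetical helper names; in PyTorch the same logic would use .long() and torch.arange.

```python
import numpy as np


def valid_edges(values, lengths):
    """values: padded edges [B, max_len, 2] as floats (as they arrive in the model).
    lengths: number of real edges per sample, shape [B].
    Returns a list with one int64 array [n_i, 2] per sample, padding removed."""
    idx = values.astype(np.int64)        # cast float observations back to indices
    counts = lengths.astype(np.int64)
    return [idx[b, :counts[b]] for b in range(idx.shape[0])]


def padding_mask(lengths, max_len):
    """Boolean mask [B, max_len]: True for real edges, False for padding rows."""
    return np.arange(max_len)[None, :] < lengths[:, None]


values = np.array([[[0, 1], [1, 2], [0, 0]],
                   [[2, 0], [0, 0], [0, 0]]], dtype=np.float32)
lengths = np.array([2.0, 1.0], dtype=np.float32)
print(valid_edges(values, lengths))
print(padding_mask(lengths, max_len=3))
```

The list version is convenient for per-graph processing; the mask version keeps everything batched, which is usually preferable inside a model's forward pass.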
You can extend the same approach to node embeddings.
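One possible way to do that, as an assumption on my part rather than anything prescribed by RLlib: add a node-feature entry next to the edges, e.g. a hypothetical 'nodes' key holding a Box of shape (N, feat_dim), and combine the two in the model with a message passing step. The numpy sketch below does one round of mean neighbour aggregation for a single graph; mean_aggregate is a hypothetical name, and mean aggregation is just one simple choice (similar in spirit to a GCN layer without weights).

```python
import numpy as np


def mean_aggregate(node_feats, edges, num_edges):
    """One round of mean neighbour aggregation for a single graph.

    node_feats: float array [num_nodes, feat_dim]
    edges:      padded float array [max_len, 2] of (source, target) pairs
    num_edges:  number of real (non-padding) rows in `edges`
    Returns [num_nodes, feat_dim]: for each node, the mean of the features of
    nodes with an edge pointing to it (all zeros if it has no incoming edge).
    """
    e = edges[:int(num_edges)].astype(np.int64)   # drop padding, cast to indices
    src, dst = e[:, 0], e[:, 1]
    out = np.zeros_like(node_feats, dtype=np.float64)
    np.add.at(out, dst, node_feats[src])          # sum incoming messages per node
    deg = np.bincount(dst, minlength=len(node_feats)).astype(np.float64)
    return out / np.maximum(deg, 1.0)[:, None]    # mean; avoid division by zero


node_feats = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edges = np.array([[0, 1], [2, 1], [1, 0], [0, 0]], dtype=np.float32)
print(mean_aggregate(node_feats, edges, num_edges=3))
```

np.add.at is used instead of out[dst] += ... because the latter would silently drop repeated target indices.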