Hello,
I’m trying to train a graph network with RLlib. For this purpose I’m using the PyTorch Geometric library.
This library has its own way of collating samples into a batch, because each sample is a graph of a different size (more info here).
I’m a bit lost on how to integrate PyTorch Geometric with RLlib.
Should I convert the observation to a tensor graph representation in my custom environment, or should I do it on the model side?
Alternatively, can I change the way samples are collated into a batch?
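For context, PyTorch Geometric batches variable-size graphs by treating the whole batch as one large disconnected graph. Node feature matrices are concatenated, each graph's edge indices are shifted by the number of nodes that come before it, and a `batch` vector records which graph every node belongs to. The NumPy sketch below only illustrates that idea. It is not PyG's actual implementation, and `collate_graphs` is a hypothetical helper.

```python
import numpy as np


def collate_graphs(graphs):
    """Merge (x, edge_index) graphs into one disjoint-union graph.

    x has shape [num_nodes, F]; edge_index has shape [2, num_edges]
    with node ids that are local to that graph.
    """
    xs, edge_indices, batch = [], [], []
    offset = 0
    for graph_id, (x, edge_index) in enumerate(graphs):
        xs.append(x)
        # Shift local node ids so they index into the concatenated node matrix.
        edge_indices.append(edge_index + offset)
        # Remember which graph each node came from (used later for pooling).
        batch.append(np.full(x.shape[0], graph_id, dtype=np.int64))
        offset += x.shape[0]
    return (
        np.concatenate(xs, axis=0),
        np.concatenate(edge_indices, axis=1),
        np.concatenate(batch),
    )


# Two toy graphs: a 3-node path and a 2-node edge.
g1 = (np.ones((3, 2)), np.array([[0, 1], [1, 2]]))
g2 = (np.zeros((2, 2)), np.array([[0], [1]]))
x, edge_index, batch = collate_graphs([g1, g2])
print(x.shape)              # (5, 2)
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
```

Because RLlib stacks observations into fixed-shape tensors, this merge cannot simply replace RLlib's own batching. It can, however, be done inside a custom model once the batch arrives.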
Hey @Astariul, have you taken a look at the Repeated space that RLlib offers? It allows for variable-sized observations. There is an example script here that shows how to use it with an env that outputs variable-sized observations:
```python
"""Example of using variable-length Repeated / struct observation spaces.

This example shows:
  - using a custom environment with Repeated / struct observations
  - using a custom model to view the batched list observations

For PyTorch / TF eager mode, use the `--framework=[torch|tf2|tfe]` flag.
"""
import argparse
import os

import ray
from ray import tune
from ray.rllib.models import ModelCatalog
from ray.rllib.examples.env.simple_rpg import SimpleRPG
from ray.rllib.examples.models.simple_rpg_model import CustomTorchRPGModel, \
    CustomTFRPGModel

parser = argparse.ArgumentParser()
# ... (this file has been truncated)
```
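Applied to graphs, one option is to do the padding in the environment and the PyG-style merge in the model. As I understand it, RLlib's Repeated space hands the model padded values plus per-sample lengths. The sketch below imitates that by hand with plain NumPy, so it runs without Ray. Several parts are assumptions for illustration and not RLlib API: the hypothetical `pad_graph` and `unpad_batch` helpers, the `nodes`/`edges`/`num_nodes`/`num_edges` keys, and the `MAX_NODES`/`MAX_EDGES` bounds. Because observations may reach the model as floats, the sketch casts indices back to integers.

```python
import numpy as np

MAX_NODES = 4      # hypothetical upper bound on nodes per graph
MAX_EDGES = 6      # hypothetical upper bound on edges per graph
NUM_FEATURES = 2   # hypothetical node feature size


def pad_graph(x, edge_index):
    """Env side: encode one graph as fixed-shape arrays plus counts."""
    num_nodes, num_edges = x.shape[0], edge_index.shape[1]
    assert num_nodes <= MAX_NODES and num_edges <= MAX_EDGES
    nodes = np.zeros((MAX_NODES, NUM_FEATURES), dtype=np.float32)
    nodes[:num_nodes] = x
    edges = np.zeros((MAX_EDGES, 2), dtype=np.int64)
    edges[:num_edges] = edge_index.T  # one (src, dst) row per edge
    return {"nodes": nodes, "edges": edges,
            "num_nodes": num_nodes, "num_edges": num_edges}


def unpad_batch(obs):
    """Model side: turn a padded batch back into one disjoint-union graph."""
    xs, edge_indices, batch = [], [], []
    offset = 0
    for i in range(obs["nodes"].shape[0]):
        n = int(obs["num_nodes"][i])
        e = int(obs["num_edges"][i])
        xs.append(obs["nodes"][i, :n])
        # Drop padding rows, restore integer ids, shift by the running offset.
        edge_indices.append(obs["edges"][i, :e].astype(np.int64).T + offset)
        batch.append(np.full(n, i, dtype=np.int64))
        offset += n
    return (np.concatenate(xs),
            np.concatenate(edge_indices, axis=1),
            np.concatenate(batch))


# Simulate the batch a model would receive by stacking per-step observations.
steps = [
    pad_graph(np.ones((3, 2)), np.array([[0, 1], [1, 2]])),
    pad_graph(np.zeros((2, 2)), np.array([[0], [1]])),
]
obs = {k: np.stack([np.asarray(s[k]) for s in steps]) for k in steps[0]}
x, edge_index, batch = unpad_batch(obs)
print(x.shape)              # (5, 2)
print(edge_index.tolist())  # [[0, 1, 3], [1, 2, 4]]
print(batch.tolist())       # [0, 0, 0, 1, 1]
```

In a Torch model, these three arrays would become tensors and be passed to PyG layers. For example, a conv layer takes `(x, edge_index)` and a pooling function such as `global_mean_pool` takes `(x, batch)`. The environment keeps fixed-shape observations that RLlib can batch, and the graph-specific collation stays on the model side.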
smorad
August 9, 2021, 12:27pm