Efficient set and graph space for RL

Hi,

I’m working on RL problems that involve dynamically changing input spaces (sets and graphs).
I’m using the Repeated space, as that’s designed for this purpose, i.e. to represent a graph (following the PyG convention):

import gym
import numpy as np
from ray.rllib.utils.spaces.repeated import Repeated

# x_shape and MAX_ELEMS are defined elsewhere in my env.
d = {
    "x": Repeated(
        gym.spaces.Box(-1, 1, shape=x_shape, dtype=np.float32),
        MAX_ELEMS,
    ),
    "edge_index": Repeated(
        gym.spaces.Box(0, MAX_ELEMS, shape=(2,), dtype=np.int64),
        MAX_ELEMS ** 2,
    ),
}

observation_space = gym.spaces.Dict(d)

It works fine in general. However, it creates a significant overhead, because the allocated tensors scale with x_shape*MAX_ELEMS + 2*MAX_ELEMS**2.

I understand that the flattened observations need to have a fixed size, but I never use them anywhere. This creates a large memory and latency overhead on the GPU, because the whole zero-padded tensor is transferred from CPU to GPU. This is especially bad when the number of elements follows a long-tailed Poisson distribution, because I need to use a large MAX_ELEMS, while most of the time only a fraction of the allocated tensor is utilized.
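
To put hypothetical numbers on this: with MAX_ELEMS = 500 and x_shape = (16,), every padded observation allocates 500 * 16 = 8,000 node entries plus 2 * 500**2 = 500,000 edge-index entries, so a typical 20-node graph would use well under 1% of the memory that actually gets transferred.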

I created a graph space for this purpose, along these lines:

class GraphSpace(gym.Space):
    def __init__(self, x_shape, e_shape=None, dtype=np.float32):
        super().__init__()
        self.x_shape = x_shape  # node feature shape
        self.e_shape = e_shape  # edge feature shape (optional)
        self._shape = (0,)
        self.dtype = dtype

    def sample(self):
        # Sample a minimal single-node graph with one self-loop edge;
        # edge attributes are only generated if e_shape was given.
        if self.e_shape:
            edge_attr = self.np_random.normal(size=self.e_shape)[np.newaxis].astype(
                self.dtype
            )
        else:
            edge_attr = None

        x = self.np_random.normal(size=self.x_shape)[np.newaxis].astype(self.dtype)

        return {
            "x": x,
            "edge_attr": edge_attr,
            "edge_index": np.array([[0], [0]]),
        }

    def contains(self, x):
        # Placeholder: accept everything for now.
        return True

    def __repr__(self):
        return "Graph({}, {})".format(self.x_shape, self.e_shape)

It does work in principle; however, a few bits need to be adjusted.
E.g., for batching, RLlib assumes arrays with the same dimensionality, but here that’s not true, so instead of a normal NumPy array (size: batch x features) it creates an object array, which breaks slicing downstream as well as PyTorch tensor conversion.

I haven’t finished the conversion because there are a lot of places that need to be changed, but it seems to me that there’s no fundamental issue in allowing variable input sizes, only a syntactic one.

Also, PyTorch Geometric has a very convenient batching mechanism that addresses the same problem, while making it possible to handle the input tensors as regular 2D tensors of size BN x F (batch size * number of items x feature size). Integrating this approach would be a large endeavour, but it could enable a lot of applications that are limited by the current implementation.
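
For illustration (the shapes and values here are made up), PyG’s Batch.from_data_list concatenates graphs of different sizes along the node dimension into a single BN x F node tensor, plus a batch vector mapping each node back to its graph:

import torch
from torch_geometric.data import Batch, Data

g1 = Data(x=torch.randn(3, 8), edge_index=torch.tensor([[0, 1], [1, 2]]))
g2 = Data(x=torch.randn(5, 8), edge_index=torch.tensor([[0, 4], [2, 3]]))

batch = Batch.from_data_list([g1, g2])
print(batch.x.shape)  # torch.Size([8, 8]): (3 + 5) nodes x 8 features
print(batch.batch)    # tensor([0, 0, 0, 1, 1, 1, 1, 1]): node -> graph index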

So, is there any plan to enable more efficient and flexible dynamic spaces? I.e. pretty much the same functionality that Repeated has, but without the overhead of zero-padding.

This does sound like it will break a lot of assumptions in RLlib, but there’s currently no plan to support additional dynamic spaces.
I wonder if you can reuse some of the utils (loss fn, etc.) from RLlib and build a custom solution.

Yeah, I could work around a few of the assumptions, but honestly I don’t know how deep the issue goes; I would think that regular tensors are quite fundamental to the RLlib data pipelines, unfortunately. On the other hand, the issue (at least for me) is on the model input side, so the loss function, action distributions, etc. are not affected. That would change, though, if there was a need for dynamic action spaces.

For the observations I think there’s a way around it, which involves something similar to RepeatedValues with a matching preprocessor (see RepeatedValuesPreprocessor) that would handle the batching and stacking. Maybe each observation could include a dummy scalar value for the obs_flat bit that is expected by some parts of the experience replay and the like, but I’m not sure if that would work; a rough sketch is below.
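
Untested, and purely to illustrate the shape of the idea (GraphPreprocessor is a hypothetical name, and how the structured observation would still reach the model is exactly the unresolved part):

import numpy as np
from ray.rllib.models.preprocessors import Preprocessor

class GraphPreprocessor(Preprocessor):
    def _init_shape(self, obs_space, options):
        # Expose a dummy 1-element "flattened" shape, so the parts of
        # the pipeline that expect a fixed-size obs_flat stay happy.
        return (1,)

    def transform(self, observation):
        # Return only the dummy scalar; the structured graph observation
        # would still need to reach the model some other way.
        return np.zeros(1, dtype=np.float32)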

Worst case, the environment could put the observation directly into the object store and only pass the reference through the pipeline; the model would then retrieve the objects based on their references.

Something like:

# Environment
def step(action):
    ...
    object_ref = ray.put(obs)
    return object_ref, reward, done, info

# Policy
def forward(input_dict):
    obs = [ray.get(ref) for ref in input_dict["obs"]]

    ...
    # to tensor, to GPU, etc.

It feels quite hacky though.

Did you ever figure out how to make this work by any chance?

No, I’m still using RepeatedValues; it’s quite inefficient, but it works for now.

Note that gym now also has a dynamic Graph space implementation: see gym/spaces/graph.py in the openai/gym repo on GitHub.
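
Roughly, assuming a recent gym version (0.25+), it can be used like this:

import numpy as np
import gym

space = gym.spaces.Graph(
    node_space=gym.spaces.Box(low=-1, high=1, shape=(8,), dtype=np.float32),
    edge_space=gym.spaces.Box(low=-1, high=1, shape=(4,), dtype=np.float32),
)
g = space.sample()
# g is a GraphInstance with variable-size members:
# g.nodes:      (num_nodes, 8) node features
# g.edges:      (num_edges, 4) edge features
# g.edge_links: (num_edges, 2) source/target node indices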

vakker00 I was really pleased to see someone in the same situation.

Like the approach you came up with, I tried to bypass it this way (see the code snippets below),

but as I’m using remote rollout workers, ray.get somehow blocks other remote operations and hangs forever.

Until the Ray team supports a flexible observation space “adaptor”, we might have to write observations to a file rather than exploiting the ObjectStore.

Just let me know if I’m missing something.

Env Side

import networkx as nx
import numpy as np
import ray
from ray import ObjectRef
from gym import Env, Wrapper, spaces

class ObjRefSpace(spaces.Box):
    def __init__(self):
        # 28 bytes is the binary length of a Ray ObjectRef.
        super().__init__(shape=(28,), low=0, high=255, dtype=np.uint8)

        # Build one dummy observation so sample() returns a valid array.
        obs = nx.read_gexf('qsort_multidigraph.gexf')
        obs = parse_graph(obs)
        obj_ref = ray.put(obs)
        self.dummy = RllibWrapper.encrypt(obj_ref)

    def sample(self):
        return self.dummy

    def contains(self, x):
        return True

class RllibWrapper(Wrapper):
    def __init__(self, env: Env) -> None:
        if not ray.is_initialized():
            raise RuntimeError('Ray must be initialized before using RllibWrapper')

        super().__init__(env)
        # RepeatedValues is too inefficient, so we bypass it by storing
        # the observation in the object store and passing only the reference.
        self.observation_space = ObjRefSpace()

    @staticmethod
    def encrypt(obj_ref: ObjectRef) -> np.ndarray:
        # Serialize the 28-byte ObjectRef id into a uint8 array.
        return np.array(list(obj_ref.binary()), dtype=np.uint8)

    @staticmethod
    def decrypt(arr: np.ndarray) -> ObjectRef:
        # Reconstruct the ObjectRef from its binary id.
        return ObjectRef(arr.tobytes())

    def _process_obs(self, obs):
        obs = parse_graph(obs)
        obj_ref = ray.put(obs)

        arr = RllibWrapper.encrypt(obj_ref)

        return arr

    def reset(self, **kwargs):
        obs = self.env.reset(**kwargs)

        return self._process_obs(obs)

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        obs = self._process_obs(obs)

        return obs, rew, done, info

Policy Side

    @override(TorchModelV2)
    def forward(self, input_dict, state, seq_lens):
        batch = []
        for encrypted_arr in input_dict['obs']:
            # RLlib hands the model float tensors, so cast back to uint8
            # before recovering the ObjectRef bytes.
            arr = encrypted_arr.detach().cpu().numpy().astype(np.uint8)
            obj_ref = RllibWrapper.decrypt(arr)
            # This is the call that hangs when rollout workers are remote.
            obs = ray.get(obj_ref, timeout=3)

Hm, interesting, I’m not sure why that blocks. Maybe the object references get messed up somehow when you encrypt/decrypt them?
Does this work in a simple script without any of the RLlib mechanisms?
E.g.:

obj = list(range(10))
arr = encrypt(ray.put(obj))
obj_back = ray.get(decrypt(arr))
assert obj == obj_back

Yes, it works; the snippet above is exactly how I tested it!

Now I’m working on the file write/read strategy;
it’s getting really hacky :frowning:

And even if we get the object store to work,
I expect file I/O would be faster, since the object store requires network I/O.

Correct me if I’m wrong!

@ray Asking the team: can calling ray.get inside a remote policy cause trouble?

P.S. I’m just a researcher using the high-level API (for distributed computing), not really familiar with blocking behavior!

Thanks!
Anthony