Interaction between Ray and JAX

In my application, I have multiple training workers (Ray actors) that interact with non-JAX environments. The training workers use a JAX-based policy that outputs not only an action but also other statistics during policy evaluation. These statistics, along with the environment outputs, are stacked together and sent to another Ray actor for non-JAX post-processing.

This process roughly looks like the following:

import numpy as np
import ray

def post_process(traj):
    # do something in pure Python using numpy
    return processed

@ray.remote
class Replay:
    def __init__(self):
        self.processed = []

    def add(self, traj):
        self.processed.extend(post_process(traj))

@ray.remote
class DataWorker:
    def run(self):
        # interact with environment using a JAX policy
        traj = [(step.env.out, step.policy.statistics) for step in steps]
        # step.policy.statistics are DeviceArrays
        return traj

def main():
    ray.init()
    replay = Replay.remote()
    data_worker = DataWorker.remote()
    # Ray resolves the ObjectRef returned by run() before Replay.add executes
    replay.add.remote(data_worker.run.remote())
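For the "stacked together" step, a minimal sketch is to convert each leaf with `np.asarray` (which also copies a JAX array back to host memory) and stack along a new leading time axis. This assumes every step yields fixed-shape env outputs and a dict of statistics with the same keys; the helper name `stack_trajectory` and this layout are hypothetical. The result is a few contiguous NumPy arrays instead of many small per-step objects, which is the kind of payload Ray's zero-copy path is designed for.

```python
import numpy as np


def stack_trajectory(steps):
    """Stack per-step (env_out, stats) pairs into contiguous NumPy arrays.

    Hypothetical layout: `steps` is a list of (env_out, stats) tuples, where
    env_out is array-like and stats is a dict mapping names to array-likes
    (NumPy or JAX arrays). np.asarray copies JAX arrays to host memory.
    """
    env_outs = np.stack([np.asarray(env_out) for env_out, _ in steps])
    stat_names = steps[0][1].keys()
    stats = {
        name: np.stack([np.asarray(s[name]) for _, s in steps])
        for name in stat_names
    }
    return env_outs, stats


if __name__ == "__main__":
    # Fake 3-step trajectory standing in for real env/policy outputs.
    fake_steps = [
        (np.full(4, t, dtype=np.float32),
         {"logp": np.float32(-t), "value": np.ones(2)})
        for t in range(3)
    ]
    obs, stats = stack_trajectory(fake_steps)
    print(obs.shape, stats["logp"].shape, stats["value"].shape)
```

With this layout, `DataWorker.run` would return one array per field (time-major) rather than a list of per-step tuples.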

The problem I am having is that the post-processing is, for some reason, very slow and is currently the bottleneck of the system. Does anyone have experience handling JAX device arrays across processes, especially with Ray? If so, do you have any tips for building a system like this?

Here are some specific questions I would like to ask:

  • Transferring NumPy arrays in Ray is convenient thanks to zero-copy reads from the object store. How does this interact with JAX device arrays?
  • Should the data worker manually transfer the data back to the CPU using device_put before letting Ray serialize it?
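On the second question, one option worth trying (a sketch, not a verified fix for the slowdown) is to call `jax.device_get` on the whole trajectory inside the data worker before returning it. `jax.device_get` accepts a nested container (pytree) of arrays and returns the same structure with NumPy arrays in place of JAX arrays, so Ray only sees plain NumPy data. The helper name `to_host` and the trajectory layout below are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_host(traj):
    """Copy every JAX array in a (nested) trajectory to host memory.

    jax.device_get walks lists, tuples and dicts and converts each JAX
    array leaf to a numpy.ndarray, keeping the container structure.
    """
    return jax.device_get(traj)


if __name__ == "__main__":
    # Hypothetical trajectory: (env output, policy statistics) per step.
    traj = [
        (np.zeros(4), {"logp": jnp.float32(-0.5), "value": jnp.ones(2)})
        for _ in range(3)
    ]
    host_traj = to_host(traj)
    print(type(host_traj[0][1]["value"]))
```

Inside `DataWorker.run` this would amount to returning `to_host(traj)` instead of `traj`. Whether it removes the bottleneck depends on where the time actually goes, so timing `post_process` on its own is a useful first check.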

@Jimmy could you help with this?

Also cc @xwjiang2010