In my application, I have multiple training workers (as Ray actors) that interact with non-JAX environments. The training workers use a JAX-based policy that outputs not only an action but also other statistics during policy evaluation. The output statistics, along with the environment outputs, are stacked together and sent to another Ray actor for a non-JAX post-processing procedure.
The process roughly looks like the following:
    def post_process(traj):
        # do something in pure Python using numpy
        return processed

    @ray.remote
    class Replay:
        def __init__(self):
            self.processed = []

        def add(self, traj):
            self.processed.extend(post_process(traj))

    @ray.remote
    class DataWorker:
        def run(self):
            # interact with environment using a JAX policy
            traj = [(step.env.out, step.policy.statistics) for step in steps]
            # step.policy.statistics are DeviceArrays
            return traj

    def main():
        replay.add.remote(data_worker.run.remote())
The problem I am having is that the post-processing is really slow for some reason and is currently the bottleneck of the system. Does anyone have experience handling JAX device arrays across processes, especially when working with Ray? If so, do you have tips for building a system like this?
Here are some specific questions I would like to ask:
- Transferring NumPy arrays is quite convenient in Ray thanks to zero-copy. How does this interact with JAX device arrays?
- Should the data worker manually transfer the data back to the CPU using device_put before letting Ray serialize the data?
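
To make the last question concrete, here is a minimal sketch of what "transfer back to CPU before serializing" could look like. It uses jax.device_get rather than device_put, since device_get copies a whole nested structure to host memory and returns NumPy arrays. The helper name to_host, the field names logp and value, and the shapes are all hypothetical and only there for illustration. Whether this actually removes the bottleneck in my setup is untested.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_host(traj):
    """Copy every JAX array in a (possibly nested) trajectory to host memory.

    jax.device_get walks the pytree and returns NumPy arrays in place of
    JAX arrays; leaves that are already NumPy arrays stay NumPy arrays.
    """
    return jax.device_get(traj)


# Hypothetical trajectory: one (env_out, policy_statistics) pair per step.
traj = [
    (
        np.ones(3, dtype=np.float32),                      # environment output
        {"logp": jnp.zeros(2), "value": jnp.ones(())},     # JAX policy statistics
    )
    for _ in range(4)
]

# What DataWorker.run could return instead of the raw trajectory.
host_traj = to_host(traj)

# Optionally stack the per-step leaves into one array per field, so the
# post-processing sees a few large NumPy arrays instead of many small ones.
stacked = jax.tree_util.tree_map(lambda *xs: np.stack(xs), *host_traj)
```

With this, the Replay actor would only ever receive plain NumPy arrays, which is the case the zero-copy question above is about.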