Can JAX speed up MultiAgentEnv implementations?

Hi! I have a subclass of MultiAgentEnv that follows the gym-like step/reset paradigm. If I switch all the NumPy code in the environment over to JAX and run it on a GPU, will Ray and Gym take advantage of jax.numpy.DeviceArray for speed?

Ray expects observations to be numpy.ndarray and calls NumPy methods on them. I suspect that returning arrays created with jax.numpy.array will cause failures unless those arrays support the numpy.ndarray methods Ray relies on.
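One possible workaround is to keep the heavy math in JAX inside the environment and convert to numpy.ndarray only at the step/reset boundary. Below is a minimal sketch of that idea, not a tested RLlib integration: `JaxBackedEnv`, `_step_physics`, and the agent IDs are hypothetical names, the class deliberately does not subclass MultiAgentEnv, and `step` returns only observations to keep the example short.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def _step_physics(positions, actions):
    # Hypothetical dynamics: move each agent by its action, clipped to [-1, 1].
    return jnp.clip(positions + actions, -1.0, 1.0)


class JaxBackedEnv:
    """Hypothetical two-agent environment (illustration only)."""

    agent_ids = ("agent_0", "agent_1")

    def reset(self):
        # State lives in a JAX array between calls.
        self._positions = jnp.zeros((len(self.agent_ids), 2), dtype=jnp.float32)
        return self._to_obs_dict(self._positions)

    def step(self, action_dict):
        # Stack the per-agent actions into one array so the jitted
        # function processes all agents in a single call.
        actions = jnp.stack(
            [jnp.asarray(action_dict[a], dtype=jnp.float32) for a in self.agent_ids]
        )
        self._positions = _step_physics(self._positions, actions)
        return self._to_obs_dict(self._positions)

    def _to_obs_dict(self, positions):
        # Convert once to a host-side numpy.ndarray, then slice per agent,
        # so the caller only ever sees NumPy arrays.
        host = np.asarray(positions)
        return {a: host[i] for i, a in enumerate(self.agent_ids)}


env = JaxBackedEnv()
obs = env.reset()
obs = env.step({"agent_0": [0.5, 2.0], "agent_1": [-3.0, 0.25]})
print(type(obs["agent_0"]))
```

Note that every step then pays for a transfer from the device back to the host, so for small environments this could easily be slower than plain NumPy; whether it helps would need to be measured.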

This raises the question of whether it is worth moving the core Ray logic from NumPy to JAX, but I suspect that would be a lot of work.