Hi! I have a subclass of `MultiAgentEnv` that follows the Gym-style `step`/`reset` paradigm. If I switch all the NumPy code in the environment over to JAX and run it on a GPU, will Ray and Gym take advantage of `jax.numpy.DeviceArray` for speed?
Ray expects the observations to be `numpy.ndarray`s and calls `numpy` methods on them. I suspect that using `jax.numpy.array` will cause failures unless those arrays can be operated on by `numpy.ndarray` methods.
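One common way to sidestep this is to keep the environment's internal state as JAX arrays (so the per-step math can be jitted and run on the GPU) and convert to NumPy only at the `step`/`reset` boundary, before anything reaches Ray. The sketch below assumes that approach. `JaxBackedEnv`, `_step_dynamics`, the toy reward, and the agent IDs are all hypothetical, and the class deliberately does not subclass Ray's `MultiAgentEnv` so it runs without Ray installed. The 4-tuple return with a `"__all__"` done flag follows the older Gym-style RLlib API; newer RLlib versions use separate terminated/truncated dicts and have `reset` return `(obs, infos)`.

```python
import numpy as np
import jax
import jax.numpy as jnp


@jax.jit
def _step_dynamics(state, actions):
    # Hypothetical dynamics: each agent's state simply moves by its action.
    # This runs on whatever device JAX is using (GPU if available).
    return state + actions


class JaxBackedEnv:
    """Hypothetical stand-in for a MultiAgentEnv subclass (no Ray import)."""

    def __init__(self, num_agents=2, obs_dim=3):
        self.agent_ids = [f"agent_{i}" for i in range(num_agents)]
        # Internal state lives on the JAX device.
        self.state = jnp.zeros((num_agents, obs_dim), dtype=jnp.float32)

    def _to_obs_dict(self):
        # Copy the state from device to host once per step, so Ray only
        # ever sees plain numpy.ndarray observations.
        host_state = np.asarray(jax.device_get(self.state))
        return {aid: host_state[i] for i, aid in enumerate(self.agent_ids)}

    def reset(self):
        self.state = jnp.zeros_like(self.state)
        return self._to_obs_dict()

    def step(self, action_dict):
        # Actions arrive from Ray as NumPy arrays; move them onto the device.
        actions = jnp.stack(
            [jnp.asarray(action_dict[aid], dtype=jnp.float32) for aid in self.agent_ids]
        )
        self.state = _step_dynamics(self.state, actions)
        obs = self._to_obs_dict()
        # Toy reward: negative L1 norm of each agent's observation.
        rewards = {aid: float(-np.abs(obs[aid]).sum()) for aid in self.agent_ids}
        dones = {aid: False for aid in self.agent_ids}
        dones["__all__"] = False
        infos = {aid: {} for aid in self.agent_ids}
        return obs, rewards, dones, infos


env = JaxBackedEnv()
obs = env.reset()
obs, rewards, dones, infos = env.step({"agent_0": np.ones(3), "agent_1": np.zeros(3)})
print(type(obs["agent_0"]), rewards)
```

Note that the device-to-host copy happens on every step, so any GPU speedup comes only from the computation inside `step`; Ray's own sampling and batching would still operate on NumPy.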
This raises the question of whether it is worth moving the core Ray logic from `numpy` to `jax`, but I suspect that would be a lot of work.