I see that some preliminary work has been done to create a JAXPolicy within RLlib, as well as a couple of PRs from a year or two ago. What is the current state of this work? Is it currently possible to use JAX with RLlib? The most relevant issue I found was this one, but it doesn't state the current status. I'd love to know how feasible this currently is, or whether this is a project I could help contribute to.
Thanks
Any and all contributions are always welcome!
This is not on our immediate roadmap because we haven't seen strong demand for JAX-backed policies.
But we understand that JAX is becoming more and more popular, so it is likely just a matter of time before we need to support it.
Can you say a bit more about your use case, and why TF or PyTorch don't satisfy your needs?
The case for JAX implementations of both the policy and the environment in RL is really strong: if everything is in JAX, we can jit-compile entire rollouts (or large parts of them) and take maximal advantage of hardware accelerators by minimising context switches between the CPU and the accelerator. I'm only just starting to look into Ray/RLlib, and I'm wondering how far this can be pushed in the distributed setting. Jit compilation adds overhead the first time a function is called, so the benefits would be severely restricted if every episode had to recompile. I assume the RolloutWorkers live throughout training and just communicate back and forth with the main node, since the environments have to keep their state?
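To illustrate what I mean, here's a minimal, self-contained sketch (nothing to do with RLlib's actual API; `env_step`, `policy`, and `rollout` are hypothetical stand-ins I made up): with a pure-JAX environment and policy, a whole episode can be wrapped in `jax.lax.scan` and jit-compiled once, then replayed on the accelerator with no per-step host round trips.

```python
import jax
import jax.numpy as jnp

NUM_STEPS = 100  # episode length; must be a static value for lax.scan

def env_step(state, action):
    # Toy pure-functional environment: a 1-D point nudged by the action,
    # rewarded for staying near the origin.
    next_state = state + 0.1 * action
    reward = -jnp.abs(next_state)
    return next_state, reward

def policy(params, state):
    # Toy linear policy squashed through tanh.
    return jnp.tanh(params * state)

@jax.jit
def rollout(params, init_state):
    # One scan step: act, step the env, carry the new state, emit the reward.
    def step(state, _):
        action = policy(params, state)
        next_state, reward = env_step(state, action)
        return next_state, reward

    final_state, rewards = jax.lax.scan(step, init_state, None, length=NUM_STEPS)
    return final_state, rewards.sum()

params = jnp.array(0.5)
state = jnp.array(1.0)

# The first call pays the one-off XLA compilation cost...
final_state, episode_return = rollout(params, state)
# ...subsequent calls with the same shapes/dtypes reuse the compiled program,
# so the per-episode overhead vanishes as long as the worker process lives on.
final_state, episode_return = rollout(params, state)
```

The second call reuses the compiled program as long as shapes and dtypes match, which is why the lifetime of the workers matters: recompiling on every episode would throw away most of the benefit.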