To what extent is RLlib compatible with JAX?

Hello all! I was wondering whether there has been any progress on RLlib's compatibility with JAX. The most recent post I could find on this was "Using RLlib with jax", but that was from 2022. I also see snippets of code in the RLlib GitHub repository suggesting there may be some compatibility, but I'm unsure to what extent it all works. Could I write a PPO implementation in JAX within RLlib, or use custom models written in JAX?

@cwijesundara, good question. We are currently moving from the old stack, which mainly uses ModelV2 for models together with the Policy, to a new stack based on our new RLModule API together with the EnvRunner API for sampling. JAX was definitely considered while planning the new APIs, but the established frameworks (PyTorch and TensorFlow) have priority. Whether JAX support gets implemented is mainly a question of how many users ask for it.

You might want to open a feature request issue on GitHub so that other users can vote on it and describe their use cases. There is no guarantee that it will actually be implemented, of course.

Understood, thanks, Lars!