Ray Train and sharded models with JAX: possible? Desirable?

Got a case where I’m trying to help scale training of a large model that won’t fit on a single GPU, trained on a large dataset (a familiar story…). The model in this case uses JAX, so it’s not as obvious how to adapt it as it would be for, say, a PyTorch Lightning setup. The JAX tutorials I’ve seen are pretty JAX-native (using jax.distributed, create_device_mesh, etc.).

Is it possible, or remotely desirable, to try to mix this with Ray? I’ve used Ray Train in a more Torch-driven environment and found the integration great, but it’s less clear how it can interact with JAX, and I see relatively few examples…