Ray Train and sharded models with Jax--possible? Desirable?

Got a case where I’m trying to help scale training of a large model that won’t fit on a single GPU, trained on a large data set (a familiar story…). The model in this case uses Jax, so it’s not as obvious how to adapt it as it would be for, say, a PyTorch Lightning setup. The tutorials I see for Jax are pretty Jax-heavy (using jax.distributed or create_device_mesh, etc.).

Is it possible or remotely desirable to try and mix this with Ray? I’ve used ray.train in a more torch-driven environment and found the interaction to be great, but it’s less clear how it can interact with Jax, and I see relatively few examples…

Hey @vputz. Did you give this a shot by any chance? I’m thinking about doing the same thing but didn’t find too much out there on how jax.distributed.initialize() launches its distributed runtime and whether this might conflict with Ray’s GCS & distributed scheduler model.

So! Yeeeesss… although we haven’t stress tested it by any means, which is probably where the devil is in the details.

What I wound up doing is taking a yet-unmerged bit of code ([Ray Air] adding the jax trainer by JiahaoYao · Pull Request #27106 · ray-project/ray · GitHub) and using pieces of it. Since it wasn’t merged, and the PR doesn’t make it entirely clear why, I’d proceed with a little care!

I can’t share what we did for contractual reasons, but basically it’s very close to that PR. Look at the python/ray/train/jax/config.py file with the JaxConfig class, the setup_jax_gpu_environment call (which basically gets the master address, num_workers, etc. and calls jax.distributed.initialize()), and the _JaxBackend classes (we’re just using GPUs, so I cut out the TPU stuff in there). The python/ray/train/jax/jax_trainer.py file has a JaxTrainer which then handles the data-parallel training should you choose. I was able to replicate his MNIST-style training example using Ray Data; it ran to completion and got the same results as single-GPU training once batch sizes etc. were tweaked.
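If it helps, the core of that setup call is roughly this (a sketch from memory rather than the PR’s actual code; the arguments master_addr, master_port, num_workers, and worker_rank are illustrative and would be supplied by Ray Train’s worker group setup):

    import jax

    def setup_jax_gpu_environment(master_addr, master_port, num_workers, worker_rank):
        # Called once on each Ray Train worker before any Jax work:
        # point every process at the rank-0 node and tell Jax the
        # world size and this process's rank.
        jax.distributed.initialize(
            coordinator_address=f"{master_addr}:{master_port}",
            num_processes=num_workers,
            process_id=worker_rank,
        )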

So … yes, with caveats. This didn’t do anything cosmic at all; basically Ray just sets up the workers and launches them, then jax.distributed.initialize() connects them, and from there you do the work with things like Jax’s pmap to average out gradients (for data-distributed-parallel training using optax, as opposed to ensembles or whatnot):

    ...
    grads = jax.grad(loss)(state.params, batch)
    grads = dict_mean(grads)
    ...

to average the gradients, where dict_mean is something like

    def dict_mean(d):
        # Flatten the pytree of gradients, pmean each leaf across all
        # devices/processes (axis 'i'), then rebuild the pytree.
        leaves, treedef = jax.tree_util.tree_flatten(d)
        new_leaves = [
            jax.pmap(lambda x: jax.lax.pmean(x, axis_name='i'), axis_name='i')(
                jax.numpy.expand_dims(leaf, 0)
            )[0]
            for leaf in leaves
        ]
        return jax.tree_util.tree_unflatten(treedef, new_leaves)

…or something along those lines, depending on your training needs. We haven’t gotten as far as using it in practice and I’m not sure we will for a while, but in principle it worked. Again, this relies on Ray mostly just to do the coordination; once you get jax.distributed.initialize() working, it’s on you to get the training going the way you want it.
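To make the “it’s on you” part a bit more concrete, here’s a minimal sketch of how the averaged gradients could feed an optax update inside the per-worker training function (train_step, optimizer, and loss are placeholder names, not anything from the PR):

    import jax
    import optax

    def train_step(params, opt_state, batch, optimizer, loss):
        # Each worker computes gradients on its own shard of the batch,
        # averages them across all processes with dict_mean, then applies
        # the same optax update so the replicas stay in sync.
        grads = jax.grad(loss)(params, batch)
        grads = dict_mean(grads)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        return params, opt_state

That step then just runs in the loop the JaxTrainer launches on each worker, with Ray Data handing each worker its shard of the batches.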

…and of course with MNIST it’s hilariously slower to do it that way, because it’s a toy problem, and the feedback loop with Ray on a cluster makes debugging tedious, particularly with Jax and multi-node work.

But it DID seem to work; it’s not clear why that PR wasn’t merged, and I hope something official IS merged at some point, but the basics worked for what we were doing. I’m looking forward to using it for something more serious; right now the only Ray/Jax stuff we’ve used even at small, non-demo scale is some embarrassingly parallel work where it’s just individual nodes not communicating.