So! Yeeeesss… although we haven’t stress tested it by any means, which is probably where the devil is in the details.
What I wound up doing is taking a yet-unmerged bit of code ([Ray Air] adding the jax trainer by JiahaoYao · Pull Request #27106 · ray-project/ray · GitHub) and using pieces of it. Since it wasn't merged, and the PR doesn't really say why, I'd proceed with a little care!
I can't share what we did for contractual reasons, but basically it's very close to that PR. Look at the python/ray/train/jax/config.py file with the JaxConfig class, the setup_jax_gpu_environment call (which basically gets the master address, num_workers, etc. and calls jax.distributed.initialize()), and the _JaxBackend classes (we're just using GPUs, so I cut out the TPU stuff in there). The python/ray/train/jax/jax_trainer.py file has a JaxTrainer which then handles the data-parallel training should you choose. I was able to replicate his MNIST-style training example using Ray Data; it ran to completion and got the same results as single-GPU training once batch sizes etc. were tweaked.
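For a sense of what that setup amounts to, here's a minimal sketch (not the PR's exact code; the function signature and how the coordinator address and ranks get passed in are my assumptions, but jax.distributed.initialize() and its arguments are the real JAX API):

import jax

def setup_jax_gpu_environment(coordinator_address, num_processes, process_id):
    # Each Ray Train worker calls this once, before doing any JAX work,
    # so that all processes join the same distributed JAX runtime.
    jax.distributed.initialize(
        coordinator_address=coordinator_address,  # "<rank-0 node IP>:<port>", known to Ray
        num_processes=num_processes,              # total number of Ray Train workers
        process_id=process_id,                    # this worker's rank, 0..num_processes-1
    )

Once that call succeeds, jax.devices() reports every GPU in the cluster while jax.local_devices() reports only this node's, which is what makes the cross-host pmean trick below work.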
So … yes, with caveats. This didn't do anything cosmic at all; basically Ray just sets up the workers and launches them, and then jax.distributed.initialize() connects them and you can do the work with things like jax's pmap to average out gradients (for data-parallel training using optax, as opposed to ensembles or whatnot):
...
grads = jax.grad(loss)(state.params, batch)  # this worker's gradients on its shard of the batch
grads = dict_mean(grads)                     # averaged across every worker
...
to average the gradients, where dict_mean is something like:
def dict_mean(d):
    # Average every leaf of the gradient pytree across all devices/hosts.
    leaves, treedef = jax.tree_util.tree_flatten(d)
    new_leaves = [
        # pmap over a dummy leading axis of size 1 (one local GPU per worker);
        # pmean over axis 'i' then averages across every device in the cluster.
        jax.pmap(lambda x: jax.lax.pmean(x, axis_name='i'), axis_name='i')(
            jax.numpy.expand_dims(leaf, 0)
        )[0]
        for leaf in leaves
    ]
    return jax.tree_util.tree_unflatten(treedef, new_leaves)
…or something along those lines, depending on your training needs. We haven't gotten as far as using it in practice and I'm not sure we will for a while, but in principle it worked. Again, this relies mostly on Ray just to do the coordination, and once you get jax.distributed.initialize() working it's on you to get the training going the way you want it.
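To make that concrete, here's a minimal sketch of the kind of per-worker training step this adds up to (the loss, the params layout, and the adam optimizer are placeholders I'm making up for illustration; dict_mean is the helper above, and optimizer.update / optax.apply_updates are standard optax API):

import jax
import jax.numpy as jnp
import optax

def loss(params, batch):
    # Placeholder loss; in our case this was the MNIST-style model's loss.
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

optimizer = optax.adam(1e-3)

def train_step(params, opt_state, batch):
    grads = jax.grad(loss)(params, batch)  # this worker's gradients on its shard
    grads = dict_mean(grads)               # averaged across all workers (see above)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state

The point being that all the cross-worker communication lives inside dict_mean; everything else is vanilla single-process optax.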
…and of course with MNIST it's hilariously slower to do it that way because it's a toy problem, and the feedback loop with Ray on a cluster makes debugging tedious, particularly with jax and multi-node work.
But it DID seem to work; it's not clear why that wasn't merged and I hope something official IS merged at some point, but the basics worked for what we were doing. I'm looking forward to using it for something more serious; right now the only ray/jax stuff we've used even at small non-demo scale is some embarrassingly parallel stuff where it's just individual nodes not communicating.