How to use BERT in a Ray cluster?

I would like to train a BERT language model from scratch using several computers with GPU cards. A single GPU with a large amount of memory is very expensive, but a few cards with 8 GB each are available. I am considering splitting batches across machines and making asynchronous updates.
Does anybody have experience with Ray clusters in this area, or can anyone suggest a tutorial on distributed BERT training?


Hmm, maybe you could consider using RaySGD (here is an example).

Alternatively, we also have a PyTorch Lightning integration that you can use to scale training: ray_lightning (GitHub: ray-project/ray_lightning, "Pytorch Lightning Distributed Accelerators using Ray").
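
To make the batch-splitting idea from the question a bit more concrete, here is a minimal, framework-free sketch of one data-parallel training step. It is not Ray, RaySGD, or ray_lightning code, and the names (`gradient`, `split_batch`, `data_parallel_step`) and the toy model `y_hat = w * x` are made up for illustration. It shows the synchronous variant, where every worker computes a gradient on its shard and the results are averaged before a single update. An asynchronous scheme would instead apply each worker's gradient as soon as it arrives.

```python
from typing import List, Tuple

Sample = Tuple[float, float]  # (x, y) pair for the toy model y_hat = w * x


def gradient(w: float, batch: List[Sample]) -> float:
    """Gradient of the mean squared error with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def split_batch(batch: List[Sample], num_workers: int) -> List[List[Sample]]:
    """Split a batch into equally sized shards, one per (hypothetical) worker."""
    if len(batch) % num_workers != 0:
        raise ValueError("batch size must be divisible by num_workers")
    shard_size = len(batch) // num_workers
    return [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]


def data_parallel_step(w: float, batch: List[Sample], num_workers: int, lr: float) -> float:
    """One synchronous step: per-shard gradients, averaged, then one update."""
    shards = split_batch(batch, num_workers)
    # On a real cluster each of these calls would run on a different machine/GPU;
    # here they simply run one after another in a single process.
    grads = [gradient(w, shard) for shard in shards]
    avg_grad = sum(grads) / num_workers
    return w - lr * avg_grad


if __name__ == "__main__":
    data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # generated by y = 2x
    w = 0.0
    for _ in range(50):
        w = data_parallel_step(w, data, num_workers=2, lr=0.01)
    print(w)  # moves toward 2.0
```

With equally sized shards, the averaged gradient equals the gradient of the full batch, so each machine only has to hold its own shard in GPU memory while the update stays the same as in single-machine training.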

