Model Parallelism in Ray

Hi folks, it seems like Ray Train focuses on distributed training with data parallelism. I am wondering whether there is a supported use case for model parallelism. In our specific use case, we are training large-scale embeddings, which typically require model parallelism because the embedding matrix cannot fit in the memory of one machine.

That’s a great use case. You mentioned offline that you were looking at PyTorch-BigGraph in particular?

Thanks, Richard, for the quick response. Yes, in particular, we are looking to train node embeddings on large graphs and to use PyTorch-BigGraph as the framework for training over a Ray cluster.

Officially, we don’t have any pre-existing examples. However, it should work fine (given that Ray Train just constructs the process group for you).

We would be happy to help guide you through the implementation, if you have any particular questions.

(and also subsequently highlight your use case as a successful example down the road!)
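To make the process-group point above concrete, here is a minimal sketch (not PyTorch-BigGraph code; the worker count and the all_reduce are just placeholders). Inside the function you hand to TorchTrainer, the torch.distributed process group is already initialized, so you never call init_process_group yourself:

```python
# Rough sketch, not PyTorch-BigGraph code: Ray Train initializes the
# torch.distributed process group before calling this function, so the
# usual init_process_group call is not needed.
import torch
import torch.distributed as dist

import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    # The process group already exists; collectives work out of the box.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    device = ray.train.torch.get_device()

    # Placeholder for model-parallel work: each worker would hold its own
    # shard of the embedding matrix and exchange data via collectives.
    shard = torch.ones(4, device=device) * rank
    dist.all_reduce(shard, op=dist.ReduceOp.SUM)
    print(f"rank {rank}/{world_size}: {shard}")


trainer = TorchTrainer(
    train_loop_per_worker,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),  # placeholder sizes
)
trainer.fit()
```

In principle, anything that only needs a working process group should be able to run inside that function.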

Of course! Would be happy to contribute to the Ray ecosystem in any way we can.

The fact that Ray Train just constructs the process group makes sense. However, the examples provided in the documentation give the impression that the API is limited and works mostly for the data-parallel case. After digging into PyTorch-BigGraph, I think I can train with model parallelism on Ray as long as I can replace init_process_group from torch.distributed with an equivalent Ray function. Is there a similar API in Ray? Thanks!


Can you post an example of what you want to do?

Hi all,

I've come across a similar use case where we want to use Ray Train to split a large model across multiple GPUs rather than replicate it (data parallelism).

For example, I have a cluster with A10 GPUs (24 GB each), but the model requires ~50 GB of GPU memory. Can I use the Ray Train APIs to partition the model across multiple GPUs? Are there any examples?

Thanks!


Hey @jamjambles,

The easiest way would be to use DeepSpeed with a ZeRO-1/2/3 sharding strategy on top of Ray Train. See this user guide for more details: Get Started with DeepSpeed — Ray 2.7.1
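Roughly, the training function ends up looking like the sketch below; the toy model, ZeRO stage, and hyperparameters are placeholders, and the user guide above has a complete example:

```python
# Sketch only: the model, ZeRO stage, and config values are placeholders.
import deepspeed
import torch

from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    model = torch.nn.Linear(1024, 1024)  # stand-in for the real ~50 GB model

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "zero_optimization": {"stage": 3},  # ZeRO-3: shard params, grads, optimizer state
        "optimizer": {"type": "AdamW", "params": {"lr": 1e-5}},
    }

    # DeepSpeed reuses the process group that Ray Train already set up and
    # shards model state across the Ray Train workers.
    model_engine, optimizer, _, _ = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        config=ds_config,
    )

    for _ in range(10):  # placeholder training loop
        x = torch.randn(1, 1024, device=model_engine.device)
        loss = model_engine(x).square().mean()
        model_engine.backward(loss)
        model_engine.step()


trainer = TorchTrainer(
    train_loop_per_worker,
    # e.g. 4 workers to cover a ~50 GB model on 24 GB A10s; adjust to your cluster
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
trainer.fit()
```

With ZeRO-3 the parameters themselves are sharded across workers, so no single A10 has to hold the full model, which is what gets you past the 24 GB limit.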

Let me know if you're able to set that up. I'm also happy to move to a Slack conversation to discuss further!

@rliaw can you help me? I just want to use Ray to serve Llama-2-70B on 2 VMs, each with 4 A10 GPUs.