torch.nn.DataParallel with tune.run()

Hello,

on a server I have 5 GPUs.
Imagine a single torch model does not fit on a single GPU, so I have to use torch.nn.DataParallel, which splits each batch across all available GPUs.

Now I want to do hyperparameter optimization using Ray Tune (tune.run()). Even though I wrap my torch model in torch.nn.DataParallel, Ray seems to ignore this and still assigns only 1 GPU per trial, which leads to an out-of-memory error.

How can I distribute a single trial across multiple GPUs?
I tried resources_per_trial={"gpu": 5}, but without success.
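For reference, this is roughly what my training function looks like (the model and data here are just stand-ins for my real code):

```python
import torch
import torch.nn as nn
from ray import tune

class BigModel(nn.Module):
    # Stand-in for my real model, which is too large for a single GPU.
    def __init__(self, hidden_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1024, hidden_size), nn.ReLU(), nn.Linear(hidden_size, 10)
        )

    def forward(self, x):
        return self.net(x)

def train_fn(config):
    model = BigModel(config["hidden_size"])
    # DataParallel should split each batch across all GPUs visible to the trial.
    model = nn.DataParallel(model).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    for _ in range(10):
        x = torch.randn(256, 1024).cuda()          # stand-in batch
        y = torch.randint(0, 10, (256,)).cuda()
        loss = nn.functional.cross_entropy(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        tune.report(loss=loss.item())

analysis = tune.run(
    train_fn,
    config={"hidden_size": tune.choice([2048, 4096]),
            "lr": tune.loguniform(1e-4, 1e-2)},
    # I expected this to give each trial all 5 GPUs.
    resources_per_trial={"gpu": 5},
)
```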

Thanks.

Imagine a single torch model does not fit on a single GPU

By this, do you mean the model itself does not fit on the GPU, or that a single batch does not fit?

If you are indeed looking for data parallelism, I’d recommend checking out Ray Train! Also, in general it is recommended to use DistributedDataParallel instead of DataParallel.
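For example, with Ray Train you can run one data-parallel worker per GPU and let it handle the DistributedDataParallel wrapping and device placement for you. A rough sketch, assuming a recent Ray version (the imports have moved around between releases) and using a toy model and random data as stand-ins:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import ray.train.torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def train_loop_per_worker(config):
    # Toy model and data; replace with your own.
    model = nn.Sequential(nn.Linear(128, 256), nn.ReLU(), nn.Linear(256, 1))
    dataset = TensorDataset(torch.randn(1024, 128), torch.randn(1024, 1))
    loader = DataLoader(dataset, batch_size=config["batch_size"])

    # prepare_model wraps the model in DistributedDataParallel and moves it
    # to this worker's GPU; prepare_data_loader adds a DistributedSampler
    # and moves each batch to the same device.
    model = ray.train.torch.prepare_model(model)
    loader = ray.train.torch.prepare_data_loader(loader)

    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
    for _ in range(config["epochs"]):
        for x, y in loader:
            loss = nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

# One worker (and one GPU) per data-parallel replica.
trainer = TorchTrainer(
    train_loop_per_worker,
    train_loop_config={"lr": 1e-3, "epochs": 5, "batch_size": 64},
    scaling_config=ScalingConfig(num_workers=5, use_gpu=True),
)
result = trainer.fit()
```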

If your model does not fit on the GPU, you may need to use Pipeline Parallelism.
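In its simplest form that means placing different parts of the model on different GPUs and moving the activations between them; pipeline parallelism builds on this by also splitting each batch into micro-batches so the GPUs stay busy. A toy sketch of the manual split across two GPUs (layer sizes are arbitrary):

```python
import torch
import torch.nn as nn

class SplitModel(nn.Module):
    """Toy model whose two halves live on different GPUs."""

    def __init__(self):
        super().__init__()
        # First half on GPU 0, second half on GPU 1.
        self.part1 = nn.Sequential(nn.Linear(128, 4096), nn.ReLU()).to("cuda:0")
        self.part2 = nn.Sequential(nn.Linear(4096, 10)).to("cuda:1")

    def forward(self, x):
        x = self.part1(x.to("cuda:0"))
        # Move the intermediate activations to the second GPU.
        return self.part2(x.to("cuda:1"))

model = SplitModel()
out = model(torch.randn(64, 128))  # output lives on cuda:1
```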