Unable to access custom model functions when wrapping it with torch.prepare_model

We use:
ray: 2.3.0
torch: 1.13.1

We have a model that has custom functions, for example:

import torch
import ray.train.torch

class ExampleNetwork(torch.nn.Module):
    def forward(self, x):
        ...

    def custom_function_1(self):
        ...

    def custom_function_2(self):
        ...

model = ray.train.torch.prepare_model(ExampleNetwork())

When you use ray.train.torch.prepare_model, the resulting model behaves differently depending on the number of workers used in the TorchTrainer (ray.train.torch.TorchTrainer).

If we only use one worker, the model does not need to be parallelized and therefore remains a plain nn.Module, but if we use multiple workers it needs to be parallelized and instead becomes a DistributedDataParallel (torch.nn.parallel.DistributedDataParallel).

The problem with this is that when the model is parallelized, the prepare_model function does not wrap the custom functions, and to access them you need to change the calls from:
model.custom_function → model.module.custom_function
Standard functions like forward still work as intended.
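
Concretely, using custom_function_1 from the example above (the multi-worker case assumes prepare_model returned a DistributedDataParallel wrapper, as described):

model = ray.train.torch.prepare_model(ExampleNetwork())

# Single worker: model is still an ExampleNetwork, so this works.
model.custom_function_1()

# Multiple workers: model is a DistributedDataParallel wrapper, so the
# call above raises AttributeError and has to be written as:
model.module.custom_function_1()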

It is not really scalable to add checks for the type of the model before every function call, so I am wondering if there is another way to prepare the model, or a way to wrap the resulting model, so that we do not need different training code depending on the number of workers we train on.

Hi, this is a really good question.

As you said, we bypass the entire wrapping logic if world_size <= 1:
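
Roughly, the behavior looks like this (a paraphrased sketch of the idea, not Ray's actual source; _prepare_model_sketch is a hypothetical name, and the multi-worker branch assumes the distributed process group is already initialized):

from torch.nn.parallel import DistributedDataParallel

def _prepare_model_sketch(model, world_size):
    # Only wrap when there is more than one worker.
    if world_size > 1:
        # Custom methods on the original model are now only reachable
        # through the .module attribute of the wrapper.
        return DistributedDataParallel(model)
    # Single worker: the original nn.Module is returned unchanged.
    return model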

I will discuss this with the team internally and see how we can make this part of the experience better.
Thanks again for the feedback.


Thanks for the response, and I appreciate you taking this further.

One temporary workaround I have found, though I have not tested it extensively and it might have unintended consequences, is:

wrapped_model = ray.train.torch.prepare_model(ExampleNetwork())
if ray.air.session.get_world_size() > 1:
    # Multiple workers: prepare_model returned a DistributedDataParallel,
    # so unwrap it to get the original ExampleNetwork back.
    model = wrapped_model.module
else:
    # Single worker: no wrapping happened, use the model directly.
    model = wrapped_model
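
An equivalent way to express the same unwrapping, with the same caveats (it relies on DistributedDataParallel exposing the wrapped model as .module, and I have not tested it beyond that), would be:

wrapped_model = ray.train.torch.prepare_model(ExampleNetwork())
# .module exists only on the DistributedDataParallel wrapper;
# with a single worker we fall back to the unwrapped model itself.
model = getattr(wrapped_model, "module", wrapped_model)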

Yeah, that seems like a nice workaround.

Thanks @gjoliver. @AxelN, it seems you are all sorted.