We have a model that has custom functions, for example:
model = ray.train.torch.prepare_model(ExampleNetwork())
When you use ray.train.torch.prepare_model, the resulting model behaves differently depending on the number of workers used in the TorchTrainer (ray.train.torch.TorchTrainer).
If we only use one worker, the model does not need to be parallelized and therefore remains an nn.Module, but if we use multiple workers it is parallelized and becomes a DistributedDataParallel (torch.nn.parallel.DistributedDataParallel) instead.
The problem is that when the model is parallelized, prepare_model does not wrap the custom functions, so to access them you have to change the calls from:
model.custom_function → model.module.custom_function
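The behaviour can be reproduced without torch using plain stand-in classes (names here are hypothetical, only for illustration): the DDP-style wrapper stores the original model as .module and only re-exports the standard entry points, so custom attributes are not found on the wrapper itself.

```python
class ExampleNetwork:            # hypothetical stand-in for the real nn.Module
    def forward(self, x):
        return x * 2
    def custom_function(self, x): # the custom method in question
        return x + 1

class DDPStandIn:                # stand-in for DistributedDataParallel
    """Wraps a model, exposing it as .module; only forward() is re-exported."""
    def __init__(self, module):
        self.module = module
    def forward(self, x):
        return self.module.forward(x)

model = DDPStandIn(ExampleNetwork())
print(model.forward(3))                  # works: 6
print(model.module.custom_function(3))   # works: 4
# model.custom_function(3)               # AttributeError: not on the wrapper
```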
The standard functions like forward still work as intended.
It is not really scalable to add a type check before every function call, so I am wondering whether there is another way to prepare the model, or a way to wrap the resulting model, so that we do not need different training code depending on the number of workers we are going to train on.
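One idea I have considered is a thin wrapper that forwards unknown attribute lookups to the underlying module via __getattr__. Below is a minimal sketch with plain classes standing in for nn.Module and DistributedDataParallel (all names hypothetical); a real version would also need to forward __call__, and interacting with nn.Module's own __getattr__ machinery deserves care.

```python
class ExampleNetwork:            # hypothetical stand-in for the real nn.Module
    def forward(self, x):
        return x * 2
    def custom_function(self, x):
        return x + 1

class DDPStandIn:                # stand-in for DistributedDataParallel
    def __init__(self, module):
        self.module = module
    def forward(self, x):
        return self.module.forward(x)

class TransparentWrapper:
    """Forward unknown attributes to the wrapped model, falling back to
    .module when the model is a DDP-style wrapper."""
    def __init__(self, model):
        self._model = model
    def __getattr__(self, name):
        model = self.__dict__["_model"]
        try:
            return getattr(model, name)         # DDP attributes: forward, module, ...
        except AttributeError:
            return getattr(model.module, name)  # custom methods on the inner module

wrapped = TransparentWrapper(DDPStandIn(ExampleNetwork()))
print(wrapped.forward(3))           # goes through the DDP stand-in: 6
print(wrapped.custom_function(3))   # transparently reaches the inner module: 4

plain = TransparentWrapper(ExampleNetwork())
print(plain.custom_function(3))     # also works without the DDP wrapper: 4
```

With a wrapper like this the training loop could use the same attribute access regardless of worker count, though I do not know whether it interferes with anything Ray expects from the prepared model.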