[Tune] How to get a reference to the trainer's model

Hi, I'm using TorchTrainer to set up the model. How can I get a reference to the DDP-wrapped model?

 trainer = TorchTrainer(

It looks like TorchTrainer has a method to return the learned model, e.g.,

model = trainer.get_models()[0]

Is this model a reference to the trainer’s model so I can do some work with it on the driver side?

Hmm, get_models will return the right model from the training operator, yeah.

Hi @rliaw, I think that may not be right. It doesn't seem to be the same reference, so I can't access some of the model's internal state outside TorchTrainer.

For example:

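from ray.util.sgd.torch import TorchTrainer, TrainingOperator
import torch


class CurrentModel:
    def __init__(self, model):
        self.model = model

    def fit(self):
        config = {"model": self.model}
        trainer = TorchTrainer(
            training_operator_cls=TrainOperator,
            config=config)
        # Now self.model and the trainer's model are references to different
        # objects, maybe because TorchTrainer copies the model onto CUDA.


class TrainOperator(TrainingOperator):
    def setup(self, config):
        model = config["model"]
        # register() also expects an optimizer; SGD with lr=0.01 is illustrative.
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        self.model, self.optimizer = self.register(
            models=model, optimizers=optimizer)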

In the example above, inside fit I can't access the actual DDP-wrapped model's runtime state, e.g., some of the variables in the class.

Is that because, in TrainOperator, self.model refers to a model that has been copied onto CUDA, not the original model that was passed in through config['model']?
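A likely explanation (not confirmed in this thread) is that the config dict is serialized when it is sent to the remote training workers, so each worker deserializes its own copy of the model. A CUDA move alone would not break the reference, since PyTorch's `nn.Module.to()` modifies the module in place and returns `self`; DDP wrapping does add a separate wrapper object, with the original module available as its `.module` attribute. The sketch below reproduces the copy effect without Ray or PyTorch, assuming serialization is the cause: a `SimpleNamespace` stands in for the model, a pickle round-trip stands in for the driver-to-worker transfer, and `ship_to_worker` and `train_on_worker` are hypothetical helpers.

```python
import pickle
from types import SimpleNamespace


def ship_to_worker(config):
    # Assumption: like arguments sent to a remote Ray actor, the config is
    # serialized on the driver and deserialized on the worker, so the worker
    # ends up holding a copy rather than a reference.
    return pickle.loads(pickle.dumps(config))


def train_on_worker(config, steps):
    # Hypothetical stand-in for TrainingOperator.setup() plus training.
    model = config["model"]
    for _ in range(steps):
        model.step_count += 1  # runtime state updated on the worker's copy
    return model


# Stand-in for a torch module carrying some runtime state.
driver_model = SimpleNamespace(step_count=0)
config = {"model": driver_model}

worker_model = train_on_worker(ship_to_worker(config), steps=3)

print(worker_model is driver_model)  # False: two distinct objects
print(driver_model.step_count)       # 0: the driver's object never changed
print(worker_model.step_count)       # 3
```

Under this assumption, any state read on the driver has to come back as an explicit return value (or via something like get_models) rather than through the original object.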

Hmm, sorry, can you provide a bit more context about what you're trying to do?

This seems like an interesting use case that we could support better.

When I try to convert an existing codebase into a distributed version using RaySGD, one option is to move everything into a custom TrainOperator; another is to keep some logic, e.g., metrics computation, outside the TrainOperator. Initially I followed the latter approach, which breaks as I described above. Now I've moved everything into the TrainOperator and it seems to work, but this adds a lot of coding effort to using RaySGD.

So basically, RaySGD (or Ray) assumes a clean boundary between the driver and the remote workers, which makes sense, but it's a bit inconvenient when converting an existing codebase.
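One way to live with that boundary while keeping some logic on the driver is to compute metrics next to the model on the worker side and return only plain values, which the driver then aggregates. The sketch below illustrates the pattern without Ray: `worker_epoch` is a hypothetical stand-in for code moved into the TrainOperator, `to_driver` is a pickle round-trip standing in for results returned across the boundary, and the loss values are made up for illustration.

```python
import pickle
import statistics


def to_driver(value):
    # Assumption: results travel back from a worker by serialization,
    # so only plain data (numbers, dicts, lists) is worth returning.
    return pickle.loads(pickle.dumps(value))


def worker_epoch(worker_id, losses):
    # Hypothetical worker-side code playing the role of logic moved into
    # the TrainOperator: it computes metrics where the model lives and
    # returns a plain dict instead of exposing the model object itself.
    return {"worker": worker_id, "mean_loss": statistics.mean(losses)}


# Pretend two workers each saw a different shard of the data.
results = [
    to_driver(worker_epoch(0, [0.9, 0.7, 0.5])),
    to_driver(worker_epoch(1, [0.8, 0.6, 0.4])),
]

# Driver-side aggregation only ever touches the returned values.
overall = statistics.mean(r["mean_loss"] for r in results)
print(round(overall, 3))  # 0.65
```

The design choice here is that the driver never needs a live reference to the worker's model, which is what makes it compatible with the driver/worker split described above.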