Wrap RaySGD over Tune trainable class directly

I was going through the RaySGD documentation and came across the section showing that a TrainingOperator-based setup can be converted into a Tune Trainable class by calling TorchTrainer.as_trainable. Is it then possible to use the save_checkpoint method as well (Ray version 1.0.1)? I would like access to a method that returns a checkpoint path rather than just a state dict, which is all the TrainingOperator offers.

Hi @Gautham_Krishnan, the trainable returned by as_trainable is actually a class, so you can subclass it to do whatever you want (including overriding save_checkpoint).
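A minimal sketch of that subclassing idea. It assumes the base class follows the Ray 1.x Tune Trainable contract: save_checkpoint(checkpoint_dir) returns a path inside checkpoint_dir, and Tune later hands that same path to load_checkpoint. The helper name with_extra_checkpoint_files and the extra_files mapping are hypothetical. A class factory is used only so the sketch runs without Ray installed; writing class MyTrainable(TorchTrainable): directly works the same way.

```python
import os
from typing import Dict, Type


def with_extra_checkpoint_files(base_cls: Type, extra_files: Dict[str, str]) -> Type:
    """Return a subclass of `base_cls` whose checkpoints also hold `extra_files`.

    `base_cls` is meant to be the class returned by TorchTrainer.as_trainable(...).
    `extra_files` (file name -> text) is a hypothetical stand-in for your own files.
    """

    class TrainableWithExtraFiles(base_cls):
        def save_checkpoint(self, checkpoint_dir):
            # Let the wrapped trainer write its own checkpoint first.
            path = super().save_checkpoint(checkpoint_dir)
            # Then write the extra files into the same checkpoint folder.
            for name, content in extra_files.items():
                with open(os.path.join(checkpoint_dir, name), "w") as f:
                    f.write(content)
            return path

        def load_checkpoint(self, checkpoint_path):
            # Tune passes back what save_checkpoint returned; if that is a
            # file, the extra files sit in its parent directory.
            folder = (checkpoint_path if os.path.isdir(checkpoint_path)
                      else os.path.dirname(checkpoint_path))
            self.extra_files = {}
            for name in extra_files:
                with open(os.path.join(folder, name)) as f:
                    self.extra_files[name] = f.read()
            return super().load_checkpoint(checkpoint_path)

    return TrainableWithExtraFiles
```

Usage would look like MyTrainable = with_extra_checkpoint_files(TorchTrainer.as_trainable(training_operator_cls=MyOperator, num_workers=2), {"notes.txt": "..."}) followed by tune.run(MyTrainable, ...), where MyOperator is your own (hypothetical here) TrainingOperator subclass.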

Hope that helps!

Hey, I tried that, but I think I'm going wrong somewhere. An example snippet would help. Thanks! @rliaw

Can you actually provide a bit more context about what you’re trying to do?

i.e., it’d probably be more effective if you showed me the code that currently doesn’t work (so that I understand your intent), than for me to show you code that might not actually be relevant :slight_smile:

from ray import tune
from ray.util.sgd.torch import TorchTrainer, TrainingOperator


class TrainDistributed(TrainingOperator):
    def setup(self, config):
        # <Some init code> (elided): builds model, optimizer, criterion,
        # train_loader and val_loader.

        # Register model, optimizer, and loss.
        self.model, self.optimizer, self.criterion = self.register(
            models=model, optimizers=optimizer, criterion=criterion)

        # Register data loaders.
        self.register_data(train_loader=train_loader, validation_loader=val_loader)

    def train_batch(self, batch, batch_info):
        # <override of TrainingOperator.train_batch> (elided): builds `packet`,
        # the dict of metrics for this batch.
        return packet

    def validate_batch(self, batch, batch_info):
        # <override of TrainingOperator.validate_batch> (elided): builds `packet`.
        return packet

    def save_checkpoint(self, checkpoint_dir):
        # I want to define this here and have Tune call it. Right now that
        # doesn't happen: Tune uses the standard state_dict provided by
        # TorchTrainer. I can override state_dict, but I would like to
        # override this instead.
        return checkpoint_dir


def main():
    TorchTrainable = TorchTrainer.as_trainable(
        training_operator_cls=TrainDistributed,
        # ... other TorchTrainer arguments (elided in the original post)
    )
    analysis = tune.run(
        TorchTrainable,
        stop={"training_iteration": 5},
    )

Hope this helps. This saves the weights as required, but I also want to save other files that belong to a checkpoint inside the checkpoint folder, so I would like access to the checkpoint path rather than just the state dict.

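Can you override state_dict of the TrainingOperator and then, inside it, do:

# Inside your TrainingOperator subclass.
# write_my_file_to_a_tempdir() stands in for your own code; assume it
# returns the path of the file it wrote.
def state_dict(self):
    # "or {}" guards against the base implementation returning None.
    super_state = super().state_dict() or {}
    my_file = write_my_file_to_a_tempdir()
    with open(my_file, "rb") as f:
        super_state["file_contents"] = f.read()
    return super_state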
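If there are several files, one option is to pack a whole folder into bytes with io.BytesIO and tarfile, store the bytes in the state dict, and unpack them when the state is loaded. A minimal sketch; pack_dir, unpack_dir and the self.extra_dir folder are hypothetical names, and it assumes TrainingOperator exposes a load_state_dict(state_dict) hook to pair with state_dict.

```python
import io
import os
import tarfile


def pack_dir(path: str) -> bytes:
    """Archive every entry under `path` into an in-memory .tar.gz blob."""
    buffer = io.BytesIO()
    with tarfile.open(fileobj=buffer, mode="w:gz") as tar:
        for name in sorted(os.listdir(path)):
            # arcname keeps member names relative to `path`.
            tar.add(os.path.join(path, name), arcname=name)
    return buffer.getvalue()


def unpack_dir(blob: bytes, dest: str) -> None:
    """Recreate the archived files under `dest`."""
    os.makedirs(dest, exist_ok=True)
    with tarfile.open(fileobj=io.BytesIO(blob), mode="r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: refuse members that would escape `dest`.
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)


# Inside your TrainingOperator subclass (sketch):
#
#     def state_dict(self):
#         return {"extra_files": pack_dir(self.extra_dir)}
#
#     def load_state_dict(self, state_dict):
#         unpack_dir(state_dict["extra_files"], self.extra_dir)
```

Keeping the files as bytes means they travel with the state dict wherever RaySGD stores it, at the cost of holding the archive in memory.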

Sure, will try this out. As a follow-up: let's say I have already subclassed tune.Trainable and am using it for single-GPU training. Is there a way to plug this class directly into TorchTrainer and enable distributed training? I would also like to add some code to cleanup, and I don't see how that can be done with this flow. @rliaw

If you already have a training function (not a Trainable class) that you use for single-GPU training, you can use Tune's distributed Torch integration instead.

Here's an example of how you might use this: ddp_mnist_torch (Ray v2.0.0.dev0 documentation).

If I use the function approach, it looks like I'll have to wrap the model in PyTorch DDP on my own (judging from the example linked above), as well as access the underlying model through model.module? I'd appreciate it if you could point me to the differences between using and not using RaySGD for distributed training. Thanks @rliaw

Yeah, you’ll have to call DDP on your own.

The main difference between using this and RaySGD is that you don't have to restructure your standard training loop (with TorchTrainer, you have to restructure your code into a TrainingOperator).

It also supports arbitrary checkpoints here, as you can see from the example. Hope that helps!
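A minimal sketch of the function-based approach, assuming the Ray 1.x names DistributedTrainableCreator and distributed_checkpoint_dir from ray.tune.integration.torch (as used in the ddp_mnist_torch example). The toy dataset, the Linear model, the config keys lr and epochs, and the file name "checkpoint" are all hypothetical. It wraps the model in DDP by hand, shards the data with DistributedSampler, and writes an arbitrary file into the checkpoint directory. Imports live inside the functions only so the sketch can be loaded without Ray or PyTorch installed; the Ray workers import them when the function runs.

```python
import os


def train_func(config, checkpoint_dir=None):
    import torch
    from torch.nn.parallel import DistributedDataParallel
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset
    from ray import tune
    from ray.tune.integration.torch import distributed_checkpoint_dir

    # Toy regression data, a stand-in for your own dataset.
    dataset = TensorDataset(torch.randn(256, 4), torch.randn(256, 1))
    # Reads rank and world size from the process group Ray has set up,
    # so each worker iterates over a different shard.
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=config.get("lr", 0.01))
    if checkpoint_dir:
        model_state, optim_state = torch.load(os.path.join(checkpoint_dir, "checkpoint"))
        model.load_state_dict(model_state)
        optimizer.load_state_dict(optim_state)
    # You wrap the model in DDP yourself; the original model is model.module.
    model = DistributedDataParallel(model)
    criterion = torch.nn.MSELoss()

    for epoch in range(config.get("epochs", 5)):
        sampler.set_epoch(epoch)  # vary the shuffling across epochs
        total = 0.0
        for x, y in loader:
            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()
            total += loss.item()
        # Any files can be written here; as I understand it, only rank 0's
        # directory becomes the Tune checkpoint.
        with distributed_checkpoint_dir(step=epoch) as ckpt_dir:
            torch.save((model.module.state_dict(), optimizer.state_dict()),
                       os.path.join(ckpt_dir, "checkpoint"))
        tune.report(loss=total / len(loader))


def launch(num_workers=2):
    from ray import tune
    from ray.tune.integration.torch import DistributedTrainableCreator

    trainable_cls = DistributedTrainableCreator(
        train_func, num_workers=num_workers, use_gpu=False)
    return tune.run(trainable_cls, config={"lr": 0.01, "epochs": 5},
                    stop={"training_iteration": 5})
```

The loss reported here is each worker's own average over its shard, not a global one.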

Okay, got it. Would I have to handle combining the results of each worker in the function approach, or is that handled internally?

I think only the rank 0 worker results are reported.

Optionally, you can use one of the torch aggregation primitives (like torch.distributed.reduce; see "Distributed communication package - torch.distributed" in the PyTorch 1.7.0 documentation) to aggregate results.
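A minimal sketch of averaging a scalar metric across workers before reporting it. This uses all_reduce so every rank ends up with the result; torch.distributed.reduce with dst=0 would also do, given that only rank 0's report seems to be used. The helper name is hypothetical, it assumes the process group is already initialized (as it is inside the distributed Tune trainable), and it assumes the gloo backend; with NCCL the tensor would need to live on the GPU.

```python
def average_across_workers(value: float) -> float:
    """Average a scalar over all DDP workers; must be called on every rank."""
    # Imported here so the helper can be defined without PyTorch installed.
    import torch
    import torch.distributed as dist

    tensor = torch.tensor([value], dtype=torch.float64)
    # Sum in place on every rank, then divide by the number of workers.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    return tensor.item() / dist.get_world_size()
```

Inside a training function this would be used as tune.report(loss=average_across_workers(epoch_loss)), with epoch_loss being that worker's own value.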

Thank you. I would have to set up a DistributedSampler on my own too, right, if I want to assign a fixed subset of the data to each worker? I don't think the example covers this.

Yeah, you would have to do that.
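To see what assigning a subset of the data to each worker means in practice, here is a pure-Python sketch of how, as far as I know, torch.utils.data.DistributedSampler (PyTorch 1.7) assigns indices when shuffle=False and drop_last=False. It pads the index list by wrapping around until it divides evenly, then gives rank r every num_replicas-th index starting at r. shard_indices is a hypothetical helper for illustration only. In real code you would pass DistributedSampler(dataset) as the sampler= argument of your DataLoader, and call sampler.set_epoch(epoch) each epoch when shuffling, which is the default.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices worker `rank` sees, mimicking DistributedSampler(shuffle=False)."""
    num_samples = math.ceil(dataset_len / num_replicas)  # per-worker count
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    padding = total_size - len(indices)
    if padding > 0:
        # Wrap around to the start so the list divides evenly across workers.
        indices += (indices * math.ceil(padding / len(indices)))[:padding]
    # Strided split: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]
```

For example, with 10 samples and 3 workers, rank 1 gets [1, 4, 7, 0]: the shards are disjoint except for the padded samples that wrap around to the start.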