I was going through the RaySGD documentation and came across the section where a TrainingOperator-based class can be converted to a Tune Trainable by calling the as_trainable method. Is it then possible to use the save_checkpoint method as well? (Ray version: 1.0.1.) I would like access to a method that returns a checkpoint path rather than just the state dict, which is what the TrainingOperator provides.
Hi @Gautham_Krishnan, the trainable returned by as_trainable is actually a class, so you can subclass it to do whatever you want (including overriding save_checkpoint).
Hope that helps!
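The subclass-and-override pattern looks roughly like the sketch below. The class names here are hypothetical stand-ins (in your case the parent would be the class returned by as_trainable, not a hand-written one), but the shape of the override is the same:

```python
import os

# Hypothetical stand-in for the class returned by as_trainable; the real
# parent writes the standard state dict and returns a checkpoint path.
class BaseTrainable:
    def save_checkpoint(self, checkpoint_dir):
        path = os.path.join(checkpoint_dir, "state.pt")
        with open(path, "w") as f:
            f.write("standard state dict")
        return path

class MyTrainable(BaseTrainable):
    def save_checkpoint(self, checkpoint_dir):
        # Let the parent write the standard checkpoint first...
        path = super().save_checkpoint(checkpoint_dir)
        # ...then drop extra files into the same checkpoint folder.
        with open(os.path.join(checkpoint_dir, "extra.json"), "w") as f:
            f.write("{}")
        return path
```

You would then pass MyTrainable to tune.run instead of the class as_trainable returned.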
Can you actually provide a bit more context about what you’re trying to do?
i.e., it'd probably be more effective if you showed me the code that currently doesn't work (so that I understand your intent) than for me to show you code that might not actually be relevant.
class TrainDistributed(TrainingOperator):
    def setup(self, config):
        <Some init code>
        # Register model, optimizer, and loss.
        self.model, self.optimizer, self.criterion = self.register(
            models=model,
            optimizers=optimizer,
            criterion=loss,
        )
        # Register data loaders.
        self.register_data(train_loader=train_loader, validation_loader=val_loader)

    def train_batch(self, batch, batch_info):
        <override train_batch of TrainingOperator>
        return packet

    def validate_batch(self, batch, batch_info):
        <override validate_batch of TrainingOperator>
        return packet

    def save_checkpoint(self, checkpoint_dir):
        # <I want to define this here and have it recognized by Tune.
        #  Right now that doesn't happen; Tune uses the standard state_dict
        #  provided by TorchTrainer. I can override state_dict, but I would
        #  like to override this instead.>
        return checkpoint_dir

def main():
    ray.init()
    TorchTrainable = TorchTrainer.as_trainable(
        training_operator_cls=TrainDistributed,
        use_gpu=True,
        num_workers=2,
        use_tqdm=True,
    )
    analysis = tune.run(
        TorchTrainable,
        config=json.load(open('<path>')),
        stop={"training_iteration": 5},
        checkpoint_freq=2,
    )
Hope this helps. This saves the weights as required, but I want to save other files relevant to a checkpoint inside the checkpoint folder, so I'd like access to the path rather than just the state dict.
Can you override state_dict on the TrainingOperator, and inside it do:

def state_dict(self):
    super_state = super().state_dict()
    my_file = write_my_file_to_a_tempdir()
    out = io.BytesIO()
    # Read the file in binary mode so the bytes survive serialization.
    with open(my_file, "rb") as f:
        out.write(f.read())
    super_state["file_contents"] = out.getvalue()
    return super_state
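That file-packing idea can be sketched end-to-end with the standard library alone (function names here are hypothetical): stash a file's raw bytes under a key in the state dict, then write them back to disk when restoring.

```python
import os
import tempfile

def pack_file(state, path):
    # Store the file's raw bytes under a key in the state dict.
    with open(path, "rb") as f:
        state["file_contents"] = f.read()
    return state

def unpack_file(state, out_path):
    # On restore, write the stored bytes back to disk.
    with open(out_path, "wb") as f:
        f.write(state["file_contents"])

# Round trip through a temp directory.
tmpdir = tempfile.mkdtemp()
src = os.path.join(tmpdir, "meta.txt")
with open(src, "w") as f:
    f.write("hello")
state = pack_file({}, src)
dst = os.path.join(tmpdir, "restored.txt")
unpack_file(state, dst)
print(open(dst).read())  # prints "hello"
```

In the TrainingOperator setting, pack_file would run inside state_dict and unpack_file inside load_state_dict.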
Sure, will try this out. As a follow-up: let's say I've already subclassed tune.Trainable and am using it for single-GPU training. Is there a way to plug this class directly into TorchTrainer and enable distributed training? I'd also like to add some code to cleanup, and I don't see how that can be done with this flow. @rliaw
If you already have a training function (not a trainable) and are using it for single-GPU training, you can use Tune's distributed Torch integrations instead.
Here’s an example of how you might use this: ddp_mnist_torch — Ray v2.0.0.dev0
If I use the function approach, it looks like I'll have to call PyTorch DDP on my own? (Asking after looking at the example link above.) And also unwrap the model via model.module? I'd appreciate it if you could point me to the differences between using and not using RaySGD for distributed training. Thanks @rliaw
Yeah, you’ll have to call DDP on your own.
The main difference between this and RaySGD is that you don't have to restructure your standard training loop (whereas with TorchTrainer you have to restructure your code into a training operator).
It also supports arbitrary checkpoints here, as you can see from the example. Hope that helps!
Okay got it. Would I have to handle combining the results of each worker in the function approach? Or is that handled internally?
I think only the rank 0 worker results are reported.
Optionally, you can use one of the torch aggregation primitives (like torch.distributed.reduce; see Distributed communication package - torch.distributed — PyTorch 1.7.0 documentation) to aggregate results yourself.
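A minimal sketch of that aggregation idea is below. For illustration it spins up a single-process "gloo" group so it can run standalone; in the Tune function approach the process group would already exist with one process per worker, and the address/port values here are arbitrary placeholders.

```python
import os
import torch
import torch.distributed as dist

# Placeholder rendezvous settings for a standalone single-process group.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29517")
dist.init_process_group("gloo", rank=0, world_size=1)

loss = torch.tensor([0.25])  # this worker's metric
# Sum the metric from all workers onto rank 0.
dist.reduce(loss, dst=0, op=dist.ReduceOp.SUM)
if dist.get_rank() == 0:
    # Divide by world size to report a mean instead of a sum.
    mean_loss = loss.item() / dist.get_world_size()
    print(mean_loss)

dist.destroy_process_group()
```

With a real multi-worker group, only rank 0 would report mean_loss back to Tune.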
Thank you. I would also have to set up DistributedSampler on my own, right, if I want to pin subsets of the data to each worker? I don't think that's covered in the example.
Yeah, you would have to do that.
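For reference, a small sketch of how DistributedSampler shards a dataset. The num_replicas/rank values are hard-coded here for illustration; in a real training function you'd derive them from the process group and pass the sampler to your DataLoader:

```python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset of 8 items, sharded across 2 hypothetical workers.
dataset = TensorDataset(torch.arange(8).float())
shard0 = list(DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False))
shard1 = list(DistributedSampler(dataset, num_replicas=2, rank=1, shuffle=False))
print(shard0, shard1)  # disjoint index sets that together cover 0..7
```

You'd then build each worker's loader as DataLoader(dataset, sampler=...), and, when shuffling, call sampler.set_epoch(epoch) at the start of each epoch so shards reshuffle consistently.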