My understanding of `PopulationBasedTraining` is that, when it creates a trial with mutated hyperparameters, it initializes the model from a checkpoint of another (better-performing) trial.
As shown in this guide, one has to use an `if session.get_checkpoint():` conditional to check whether the new trial should start from a checkpoint of an earlier trial.
I was wondering whether there is a way to get the `trial_id` of the trial from which the checkpoint is loaded. My training function already supports saving and loading its own checkpoints, and I would like to avoid the overhead of storing the same checkpoints a second time through `session.report()`. My training function would then look like this:
```python
import os

if session.get_checkpoint():  # or some other flag
    # Assuming session.get_source_trial_dir() (or something similar) exists
    source_trial_dir = session.get_source_trial_dir()
    # By setting args.checkpoint_path, the training function will load the weights
    # (and, optionally, the optimizer state) from the specified checkpoint.
    # os.path.join avoids gluing the directory and filename together without a separator.
    args.checkpoint_path = os.path.join(source_trial_dir, "checkpoint.pth.tar")
```
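One possible workaround, sketched under assumptions rather than as a documented Ray feature: have each trial report a tiny checkpoint that records only its own trial ID and the path of the file its training function already writes. When PBT hands that checkpoint to an exploiting trial, the ID and path travel with it, so the weights are not stored twice. The helpers `make_pointer_checkpoint` and `resolve_restore` and the `CHECKPOINT_FILENAME` constant are hypothetical. The Ray wiring shown in comments assumes a Ray 2.x (AIR) release where `ray.air.session` provides `get_checkpoint()`, `get_trial_id()`, `get_trial_dir()` and `report()`, and `ray.air.checkpoint.Checkpoint` provides `from_dict()`/`to_dict()`.

```python
import os
from typing import Optional, Tuple

# Hypothetical helpers (not part of Ray): instead of saving the weights a second
# time via session.report(), report a small payload that only records which
# trial wrote the checkpoint and where the training function's own file lives.

CHECKPOINT_FILENAME = "checkpoint.pth.tar"  # file written by the training function itself


def make_pointer_checkpoint(trial_id: str, trial_dir: str) -> dict:
    """Build the small payload to wrap in a Ray checkpoint when reporting."""
    return {
        "source_trial_id": trial_id,
        "checkpoint_path": os.path.join(trial_dir, CHECKPOINT_FILENAME),
    }


def resolve_restore(payload: Optional[dict]) -> Optional[Tuple[str, str]]:
    """Return (source_trial_id, checkpoint_path), or None for a fresh start."""
    if payload is None:
        return None
    return payload["source_trial_id"], payload["checkpoint_path"]


# Inside a Ray 2.x (AIR) training function this might be wired up roughly as
# (assumed API, not executed here):
#   ckpt = session.get_checkpoint()
#   restored = resolve_restore(ckpt.to_dict() if ckpt else None)
#   if restored:
#       source_trial_id, args.checkpoint_path = restored
#   ...  # train, write CHECKPOINT_FILENAME into session.get_trial_dir()
#   session.report(metrics, checkpoint=Checkpoint.from_dict(
#       make_pointer_checkpoint(session.get_trial_id(), session.get_trial_dir())))

if __name__ == "__main__":
    payload = make_pointer_checkpoint("trial_00001", "/tmp/ray_results/exp/trial_00001")
    print(resolve_restore(payload))
    print(resolve_restore(None))
```

Because the payload only points at a file, the source trial may overwrite that file as it keeps training, and on a multi-node cluster the path must be readable from the node running the exploiting trial; copying the file on restore (or writing versioned filenames) may be needed in practice.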