My understanding of PB2 and PopulationBasedTraining is that they use checkpoints from other trials when initializing a model created for a new mutation.
As shown in this guide, one has to use an if session.get_checkpoint(): conditional statement to see if the new trial will use a checkpoint from an earlier trial.
I was wondering if there is a way to get the trial_dir or trial_id of the trial from which the checkpoint is loaded. My training function already supports storing and loading checkpoints, and I would like to avoid the overhead of storing the same checkpoints a second time with session.report(). My training function would then look like this:
import os

if session.get_checkpoint():  # or some other flag
    # Assuming session.get_source_trial_dir()
    # or something similar exists
    source_trial_dir = session.get_source_trial_dir()
    # By setting args.checkpoint_path,
    # the training function will load the weights
    # (and optionally, optimizer parameters)
    # from the specified checkpoint.
    args.checkpoint_path = os.path.join(source_trial_dir, "checkpoint.pth.tar")
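If nothing like session.get_source_trial_dir() exists, one workaround I can think of is to report only a tiny "pointer" checkpoint that names the trial owning the real file, so that the exploiting trial can look up the path itself. Below is a minimal, library-free sketch of that idea. The helper names make_pointer and resolve_checkpoint_path, the dict keys, and the example directory are all hypothetical. I assume the dict could be wrapped with Checkpoint.from_dict() and passed to session.report() in Ray 2.x, and that all trials share a filesystem, but I have not verified this with PB2.

```python
import os

CHECKPOINT_FILENAME = "checkpoint.pth.tar"  # file name used in my training function


def make_pointer(trial_dir, trial_id):
    """Build a small dict that identifies the trial owning the real checkpoint.

    Hypothetical helper: in a Ray training function this dict would be what
    gets reported instead of the full model weights.
    """
    return {"source_trial_dir": trial_dir, "source_trial_id": trial_id}


def resolve_checkpoint_path(pointer, filename=CHECKPOINT_FILENAME):
    """Return the path of the checkpoint file inside the source trial's directory."""
    return os.path.join(pointer["source_trial_dir"], filename)


# Example with a made-up trial directory.
pointer = make_pointer("/tmp/ray_results/exp/trial_0001", "trial_0001")
checkpoint_path = resolve_checkpoint_path(pointer)
```

The resolved checkpoint_path would then be assigned to args.checkpoint_path inside the if session.get_checkpoint(): branch. This only works if the source trial's directory is readable from the node running the new trial, and if the file has not been overwritten by the source trial in the meantime.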