Hello! I am trying to train a hierarchical model: I need to train each sub-policy separately, save its weights, and then load them when training the top-level policy. For this I'm using a custom TF model and a custom trainer (see below). I did manage to save the weights when checkpointing, but I cannot load them. I've tried overriding `load_checkpoint` (but the method doesn't get called unless I provide the path to a checkpoint file), and overriding the `setup` method (which fails with the error `'CustomTrainer' object has no attribute 'workers'`).
class CustomTrainer(PPOTrainer):
    """PPO trainer that also saves/loads each policy's Keras base model as ``.h5``.

    On checkpoint, every trainable policy's ``base_model`` weights are written
    next to the regular RLlib checkpoint. On setup, if the ``H5_MODEL_PATH``
    environment variable is set, weights are loaded from that directory into
    every policy and then broadcast to the remote rollout workers.
    """

    def save_checkpoint(self, checkpoint_dir):
        """Write the normal RLlib checkpoint plus one ``<policy_id>_weights.h5`` per trainable policy.

        Returns:
            Whatever ``PPOTrainer.save_checkpoint`` returns (the checkpoint path).
        """
        checkpoint_path = super().save_checkpoint(checkpoint_dir)
        self.workers.local_worker().foreach_trainable_policy(
            lambda policy, policy_id: saveModel(checkpoint_dir, policy, policy_id)
        )
        return checkpoint_path

    def setup(self, config):
        """Build the trainer, then optionally load ``.h5`` weights from ``$H5_MODEL_PATH``."""
        # The parent setup creates ``self.workers`` (and everything else the
        # trainer needs). Overriding setup without calling it is what caused
        # "'CustomTrainer' object has no attribute 'workers'".
        super().setup(config)

        weights_dir = os.environ.get("H5_MODEL_PATH")
        if weights_dir is not None:
            self.workers.local_worker().foreach_policy(
                lambda policy, policy_id: loadModel(weights_dir, policy, policy_id)
            )
            # Weights were only loaded on the local worker; push them to the
            # remote rollout workers so sampling starts from the loaded model.
            self.workers.sync_weights()
def saveModel(directory, policy, policy_id):
    """Save ``policy.model.base_model``'s weights to ``<directory>/<policy_id>_weights.h5``.

    Args:
        directory: Existing directory to write the weights file into.
        policy: RLlib policy whose model exposes a Keras ``base_model``
            (assumed from usage; TODO confirm for every policy class used).
        policy_id: Policy identifier, used in the log message and file name.
    """
    import contextlib

    logger.info("Saving policy %s weights to file", policy_id)
    with policy.model.context():
        session = policy.get_session()
        # Graph-mode TF policies hold a tf.Session that must be the default
        # while touching Keras variables; eager policies return None here,
        # in which case there is no session to enter.
        session_ctx = session.as_default() if session is not None else contextlib.nullcontext()
        with session_ctx:
            policy.model.base_model.save_weights(
                os.path.join(directory, "{}_weights.h5".format(policy_id))
            )
def loadModel(directory, policy, policy_id):
    """Load ``<directory>/<policy_id>_weights.h5`` into ``policy.model.base_model``.

    Args:
        directory: Directory that contains the ``.h5`` files written by ``saveModel``.
        policy: RLlib policy whose model exposes a Keras ``base_model``
            (assumed from usage; TODO confirm for every policy class used).
        policy_id: Policy identifier, used in the log message and file name.

    Raises:
        Whatever Keras ``load_weights`` raises, e.g. when the file is missing.
    """
    import contextlib

    logger.info("Loading the weights from file for policy %s", policy_id)
    with policy.model.context():
        session = policy.get_session()
        # Graph-mode TF policies need their session as default so the
        # assignment ops run in the right graph; eager policies have no session.
        session_ctx = session.as_default() if session is not None else contextlib.nullcontext()
        with session_ctx:
            policy.model.base_model.load_weights(
                os.path.join(directory, "{}_weights.h5".format(policy_id))
            )
The training is run using `tune.run(CustomTrainer)`, with additional config present. How can I load the weights from a specified `.h5` file? Also, is there a way to get the list of policies from the trainer class without involving the workers?