[Tune] Loading an h5 weights file in a trainer class

Hello! I am trying to train a hierarchical model: I need to train each sub-policy separately, save its weights, and then load them when training the top-level policy. For this I’m using a custom TF model and a custom trainer (see below). I managed to save the weights when checkpointing, but I cannot load them. I’ve tried overriding load_checkpoint (but the method doesn’t get called unless I provide the path to a checkpoint file) and overriding the setup method (which fails with the error 'CustomTrainer' object has no attribute 'workers').

import logging
import os

from ray.rllib.agents.ppo import PPOTrainer

logger = logging.getLogger(__name__)


class CustomTrainer(PPOTrainer):
    def save_checkpoint(self, checkpoint_dir):
        # Let PPOTrainer write its regular checkpoint first.
        real_checkpoint_dir = super(CustomTrainer, self).save_checkpoint(checkpoint_dir)

        # Additionally dump each trainable policy's base Keras model to an .h5 file.
        self.workers.local_worker().foreach_trainable_policy(
            lambda policy, policy_id: saveModel(checkpoint_dir, policy, policy_id))

        return real_checkpoint_dir

    def setup(self, config):
        if "H5_MODEL_PATH" in os.environ:
            self.workers.local_worker().foreach_policy(
                lambda policy, policy_id: loadModel(os.environ["H5_MODEL_PATH"], policy, policy_id))


def saveModel(directory, policy, policy_id):
    logger.info("Saving policy {} weights to file".format(policy_id))
    # Run inside the model's TF context and the policy's session before touching the weights.
    with policy.model.context():
        with policy.get_session().as_default():
            policy.model.base_model.save_weights(os.path.join(directory, "{}_weights.h5".format(policy_id)))


def loadModel(directory, policy, policy_id):
    logger.info("Loading the weights from file for policy {}".format(policy_id))
    with policy.model.context():
        with policy.get_session().as_default():
            policy.model.base_model.load_weights(os.path.join(directory, "{}_weights.h5".format(policy_id)))
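
Because foreach_policy visits every policy, the top-level trainer may also try to load a file for the top-level policy itself, which has no pre-trained weights yet. A minimal guard, sketched below, resolves the per-policy path using the same "{policy_id}_weights.h5" naming as saveModel and returns None when the file is missing. The helper name h5_path_for_policy is hypothetical, not part of RLlib.

```python
import os
from typing import Optional


def h5_path_for_policy(directory: str, policy_id: str) -> Optional[str]:
    """Hypothetical helper: path of the file saveModel writes for policy_id, or None if it is missing."""
    # Same "{policy_id}_weights.h5" naming convention as saveModel/loadModel.
    path = os.path.join(directory, "{}_weights.h5".format(policy_id))
    return path if os.path.isfile(path) else None
```

Inside loadModel one could call this first and return early (perhaps with a logger.info message) when it yields None, so only sub-policies with saved weights get load_weights called on them.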

Training is launched with tune.run(CustomTrainer) plus some additional config. How can I load the weights from a specified h5 file? Also, is there a way to get the list of policies from the trainer class without going through the workers?
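
One possible way to list policy IDs without touching the workers is to read them from the config itself (for example the config passed to setup()). This sketch assumes an RLlib 1.x-style config where multi-agent policies are declared under config["multiagent"]["policies"], and that a single-agent run uses RLlib's default ID "default_policy"; the helper name policy_ids_from_config is hypothetical.

```python
from typing import Any, Dict, List

# Assumption: RLlib's default single-agent policy ID.
DEFAULT_POLICY_ID = "default_policy"


def policy_ids_from_config(config: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: policy IDs declared in an RLlib 1.x-style config."""
    # Assumption: policies live under config["multiagent"]["policies"];
    # iterating a dict (id -> spec) or a collection of ids yields the ids.
    policies = config.get("multiagent", {}).get("policies") or {}
    ids = list(policies)
    # No multi-agent policies declared: fall back to the single default policy.
    return ids or [DEFAULT_POLICY_ID]
```

This only reflects what was declared in the config; policies added at runtime would not show up here.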

@amogkam could you take a look?

Calling the base class's setup() method (super().setup(config)) in the overridden one, before loading the weights, solved the issue. I still couldn’t find a straightforward way to get the IDs of the policies, but I can manage without.
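
To illustrate the fix without requiring Ray, here is a minimal sketch using a hypothetical stand-in base class (PPOTrainerStandIn) whose setup() creates self.workers, as PPOTrainer's base setup does. It contrasts an override that skips the base call with one that calls super().setup(config) first. In the real CustomTrainer the only change is making super().setup(config) the first line of setup(); the stand-in classes and the loaded_from attribute are illustrative only.

```python
import os


class PPOTrainerStandIn:
    """Hypothetical stand-in for PPOTrainer: the base setup() is what builds self.workers."""

    def setup(self, config):
        self.workers = ["local_worker"]  # placeholder for the real WorkerSet


class BrokenTrainer(PPOTrainerStandIn):
    def setup(self, config):
        # Never calls the base setup(), so self.workers does not exist yet -> AttributeError.
        return self.workers


class FixedTrainer(PPOTrainerStandIn):
    def setup(self, config):
        super().setup(config)  # build the workers first
        self.loaded_from = None
        if "H5_MODEL_PATH" in os.environ:
            # Real trainer: self.workers.local_worker().foreach_policy(... loadModel ...)
            self.loaded_from = os.environ["H5_MODEL_PATH"]
```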