I am running hyperparameter tuning with very long runs, and I would like to evaluate each trial's best checkpoint on a custom testbench. Currently, I have to wait for the whole hyperparameter sweep to finish, and only then can I get the best policy (by episode_reward_mean). I would like to do this for each trial instead.
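from ray import tune
from ray.rllib.agents.ppo import PPOTrainer

analysis = tune.run(
    PPOTrainer,
    ...
)

# Rebuild a trainer from the best trial's config, with exploration off for evaluation
config = analysis.best_config
config["explore"] = False
agent = PPOTrainer(
    env="compiler_gym",
    config=config
)
agent.restore(analysis.best_checkpoint)
policy = agent.get_policy().model
eval_policy(policy)  # my custom testbench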
This is what I am trying to do.
from ray.tune import Callback

class MyCallback(Callback):
    def on_trial_result(self, iteration, trials, trial, result, **info):
        policy = trial.checkpoint...????  # <<< How do I extract the policy network from the trial?
        eval_policy(policy)

analysis = tune.run(
    PPOTrainer,
    ...,
    callbacks=[MyCallback()]
)
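Whatever the exact way of getting at the checkpoint turns out to be, the per-trial bookkeeping is independent of Ray: keep the best episode_reward_mean seen so far for each trial, and only run the testbench when that trial improves. The sketch below is plain Python with no Ray dependency. `PerTrialBestTracker`, `update` and the `evaluate` callable are hypothetical names, not part of Ray. In Ray Tune, `update` could be driven from a `Callback` hook such as `on_trial_result` or `on_checkpoint`. How to turn the checkpoint object into a path depends on the Ray version, so that wiring is left out here. Checkpoints also only exist if the run actually saves them (e.g. via `checkpoint_freq` in `tune.run`).

```python
import math
from typing import Callable, Dict, Optional, Tuple


class PerTrialBestTracker:
    """Hypothetical helper: remembers each trial's best metric value and
    calls `evaluate(trial_id, checkpoint_path)` whenever a trial improves."""

    def __init__(self, evaluate: Callable[[str, str], None],
                 metric: str = "episode_reward_mean") -> None:
        self.evaluate = evaluate
        self.metric = metric
        # trial_id -> (best metric value so far, checkpoint path for that value)
        self.best: Dict[str, Tuple[float, str]] = {}

    def update(self, trial_id: str, result: Dict,
               checkpoint_path: Optional[str]) -> bool:
        """Record one result; return True if it was a new best and got evaluated."""
        value = result.get(self.metric)
        # Skip results without the metric or without a checkpoint to restore.
        # The reward may also be NaN before any episode has finished.
        if value is None or checkpoint_path is None or math.isnan(value):
            return False
        previous = self.best.get(trial_id)
        if previous is not None and value <= previous[0]:
            return False  # not an improvement for this trial
        self.best[trial_id] = (value, checkpoint_path)
        self.evaluate(trial_id, checkpoint_path)
        return True


if __name__ == "__main__":
    # Stand-in for eval_policy: just log which checkpoint would be tested.
    tracker = PerTrialBestTracker(lambda tid, path: print("evaluate", tid, path))
    tracker.update("trial_a", {"episode_reward_mean": 1.0}, "/ckpt/a/1")
    tracker.update("trial_a", {"episode_reward_mean": 0.5}, "/ckpt/a/2")  # ignored
    tracker.update("trial_a", {"episode_reward_mean": 2.0}, "/ckpt/a/3")
```

In a real setup, the `evaluate` callable would likely mirror the restore code above: build a `PPOTrainer` from the trial's config with `explore` set to False, restore the checkpoint, and pass the model to eval_policy. Because Tune callbacks run in the driver process, a slow testbench there would probably stall the sweep, so handing the evaluation off to a separate process or job may be worth considering.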
Thanks!