Hey @bug404, unfortunately there is no way to get the Trainer object directly from the tune.run() results. But you could do the following after tune.run:
from ray import tune
from ray.rllib.agents.ppo import PPOTrainer

results = tune.run("PPO", config=[some config], checkpoint_at_end=True)
checkpoint = results.get_last_checkpoint()
# Create a new Trainer with the same config and restore the checkpoint into it.
trainer = PPOTrainer(config=[same config as above])
trainer.restore(checkpoint)
# Print the Keras summary of the default tf model.
trainer.get_policy().model.base_model.summary()
Yeah, I think I can create a new PPOTrainer, pass the config to it, and then call get_policy(). I asked the question to see if there is a better way to do it.
You could also do something like this. I don’t know if it counts as “a better way,” but it would work. I do not think that updating after_init on the Trainer will step on any other callbacks, because I think those are defined in the policies rather than the Trainer, but @sven1977 would know better.
from ray import tune
from ray.tune.registry import register_trainable
from ray.rllib.agents.ppo import PPOTrainer

# Trainer subclass that prints the model summary right after initialization.
ModelPrintPPOTrainer = PPOTrainer.with_updates(
    after_init=lambda trainer: trainer.get_policy().model.base_model.summary())
register_trainable("ModelPrintPPOTrainer", ModelPrintPPOTrainer)
tune.run("ModelPrintPPOTrainer", ...)
The code you were using before will only work on a TensorFlow model that uses Keras and stores that Keras model in a “base_model” variable in the model class. If you are using torch, you would need a different method to print the model. If you are using tf but not Keras, which is what is happening in the example you posted, you would also need a different method to print a summary. And if your model did use Keras but did not assign it to the “base_model” member variable, you would have to adjust that attribute name.
The approach of using after_init to print the summary should work in all those cases, but the actual code in the lambda or function you define will have to be model-specific.
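For example, a model-specific function for the tf-without-Keras case could walk over the policy model’s variables instead of calling a Keras summary. This is just a sketch, assuming the standard ModelV2 interface where model.variables() returns the model’s tf variables:
# Sketch for a tf ModelV2 that does not use Keras: print each variable
# instead of calling a Keras summary (assumes the standard ModelV2.variables()).
def print_tf_model(trainer):
    model = trainer.get_policy().model
    for var in model.variables():
        print(var.name, var.shape)

ModelPrintTFPPOTrainer = PPOTrainer.with_updates(after_init=print_tf_model)
You would then register and run it with tune exactly as in the snippet above.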
Thanks @mannyv for your detailed answer and the really cool hack with after_init!
This is great. @bug404, I answered your question on how to do this for torch in another thread here on the forum:
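For quick reference, here is a minimal sketch of one way to do it for torch (not necessarily what the linked thread shows): the default torch models are torch.nn.Module subclasses, so printing the model object already shows its layer structure. Reusing the after_init trick from above, with "framework": "torch" set in the config:
# Sketch, assuming "framework": "torch" is set in the config: print() on the
# default torch models (torch.nn.Module subclasses) shows their layers.
ModelPrintTorchPPOTrainer = PPOTrainer.with_updates(
    after_init=lambda trainer: print(trainer.get_policy().model))
register_trainable("ModelPrintTorchPPOTrainer", ModelPrintTorchPPOTrainer)
tune.run("ModelPrintTorchPPOTrainer", ...)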