How to get a model summary if I use Tune?

If I use RLlib directly, I can get the model summary through:

>>> trainer = PPOTrainer(env="CartPole-v0", config={"eager": True, "num_workers": 0})
>>> policy = trainer.get_policy()
>>> policy.model.base_model.summary()
Model: "model"
Layer (type)               Output Shape  Param #  Connected to
observations (InputLayer)  [(None, 4)]   0
fc_1 (Dense)               (None, 256)   1280     observations[0][0]
fc_value_1 (Dense)         (None, 256)   1280     observations[0][0]
fc_2 (Dense)               (None, 256)   65792    fc_1[0][0]
fc_value_2 (Dense)         (None, 256)   65792    fc_value_1[0][0]
fc_out (Dense)             (None, 2)     514      fc_2[0][0]
value_out (Dense)          (None, 1)     257      fc_value_2[0][0]
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0

But if I use Tune, there isn’t a Trainer object, so how can I get the model summary?

Hey @bug404 , unfortunately, there is no way of getting the Trainer object directly from the results. But you could do the following after the Tune run has finished:

results = tune.run("PPO", config=[some config], checkpoint_at_end=True)
checkpoint = results.get_last_checkpoint()

# Create a new Trainer and restore the checkpointed state into it.
trainer = PPOTrainer(config=[same config as above])
trainer.restore(checkpoint)

Actually, you don’t even have to restore, the model summary would be the same for a freshly init’d model.


Yeah, I think I can create a new PPOTrainer, pass the config to it, and then call get_policy(). I asked the question to see if there is a better way to do it.


Sorry, no, at this point, there isn’t. :confused:


You could also do something like this. I don’t know if it counts as “a better way”, but it would work. I don’t think that updating `after_init` on the Trainer will step on any other callbacks, because I believe those are defined in the policies rather than the trainer, but @sven1977 would know better.

    from ray import tune
    from ray.tune.registry import register_trainable
    from ray.rllib.agents.ppo import PPOTrainer

    # Print the model summary right after the Trainer is initialized.
    ModelPrintPPOTrainer = PPOTrainer.with_updates(
        after_init=lambda trainer: trainer.get_policy().model.base_model.summary())
    register_trainable("ModelPrintPPOTrainer", ModelPrintPPOTrainer)
    tune.run("ModelPrintPPOTrainer", ...)

ray/ at master · ray-project/ray · GitHub

I modified this code your way, but it didn’t work. Can you have a look?

The code you were using before will only work on a TensorFlow model that uses Keras and stores that model in a `base_model` variable in the model class. If you are using Torch, you would need a different method to print the model. If you are using TF but not Keras, which is what is happening in the example you posted, then you would also need a different method to print a summary. If you were using a model that did use Keras but did not assign it to the `base_model` member variable, you would have to update that.

The approach of using `after_init` to print the summary should work in all those cases, but the actual code in the lambda or function you define will have to be model-specific.
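For the Torch case specifically, here is a hedged sketch: a `torch.nn.Module` prints a layer-by-layer summary when passed to `print`, so an `after_init` hook could simply do `print(trainer.get_policy().model)`. The model below is a hypothetical stand-in, not RLlib’s actual policy network:

```python
import torch.nn as nn

# Stand-in for a Torch policy network (hypothetical layer sizes,
# roughly mirroring the 4-obs / 2-action CartPole setup above).
model = nn.Sequential(
    nn.Linear(4, 256),   # observations -> hidden
    nn.ReLU(),
    nn.Linear(256, 2),   # hidden -> action logits
)

# print(model) is the Torch analogue of Keras' model.summary().
print(model)
num_params = sum(p.numel() for p in model.parameters())
print(f"Total params: {num_params}")
```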


Wow, it’s very useful. Thank you very much.

 # TODO: (sven): Get rid of `get_initial_state` once Trajectory

Sven in this code refers to you? :open_mouth:

Thanks @mannyv for your detailed answer and the really cool hack into the after_init :slight_smile:
This is great. @bug404 I answered your question on how to do this for torch in another thread here in the forum:


Yeah, I have seen it. Thank you very much.