How to get the model summary if I use tune.run()?

If I use RLlib directly, I can get the model summary like this:

>>> trainer = PPOTrainer(env="CartPole-v0", config={"eager": True, "num_workers": 0})
>>> policy = trainer.get_policy()
>>> policy.model.base_model.summary()
Model: "model"
_____________________________________________________________________
Layer (type)               Output Shape  Param #  Connected to
=====================================================================
observations (InputLayer)  [(None, 4)]   0
_____________________________________________________________________
fc_1 (Dense)               (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_value_1 (Dense)         (None, 256)   1280     observations[0][0]
_____________________________________________________________________
fc_2 (Dense)               (None, 256)   65792    fc_1[0][0]
_____________________________________________________________________
fc_value_2 (Dense)         (None, 256)   65792    fc_value_1[0][0]
_____________________________________________________________________
fc_out (Dense)             (None, 2)     514      fc_2[0][0]
_____________________________________________________________________
value_out (Dense)          (None, 1)     257      fc_value_2[0][0]
=====================================================================
Total params: 134,915
Trainable params: 134,915
Non-trainable params: 0

But if I use tune.run(), there isn’t a Trainer object, so how can I get the model summary?

Hey @bug404, unfortunately there is no way to get the Trainer object directly from the tune.run() results. But you could do the following after tune.run():

from ray import tune
from ray.rllib.agents.ppo import PPOTrainer

results = tune.run("PPO", config=[some config], checkpoint_at_end=True)
checkpoint = results.get_last_checkpoint()

# Create a new Trainer with the same config and load the checkpoint into it.
trainer = PPOTrainer(config=[same config as above])
trainer.restore(checkpoint)
trainer.get_policy().model.base_model.summary()

Actually, you don’t even have to restore; the model summary will be the same for a freshly initialized model.
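
A minimal sketch of that shortcut, assuming the same placeholder config as above (nothing is loaded from a checkpoint):

from ray.rllib.agents.ppo import PPOTrainer

# A freshly built Trainer has the same network architecture, so the summary is identical;
# only the weights would differ from a trained checkpoint.
trainer = PPOTrainer(config=[same config as above])
trainer.get_policy().model.base_model.summary()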


Yeah, I think I can create a new PPOTrainer, pass the same config to it, and then call get_policy(). I asked the question to see if there is a better way to do it.


Sorry, no, at this point, there isn’t. :confused:

@bug404,

You could also do something like this. I don’t know if it counts as “a better way”, but it would work. I don’t think that updating after_init on the Trainer will step on any other callbacks, because I believe those are defined in the policies rather than the Trainer, but @sven1977 would know better.

    from ray import tune
    from ray.tune.registry import register_trainable
    from ray.rllib.agents.ppo import PPOTrainer

    # Build a PPOTrainer variant that prints the Keras base-model summary
    # once the Trainer has finished initializing.
    ModelPrintPPOTrainer = PPOTrainer.with_updates(
        after_init=lambda trainer: trainer.get_policy().model.base_model.summary())
    register_trainable("ModelPrintPPOTrainer", ModelPrintPPOTrainer)

    tune.run("ModelPrintPPOTrainer", ...)

ray/custom_fast_model.py at master · ray-project/ray · GitHub
I modified this code the way you suggested, but it didn’t work. Can you have a look?

The code you were using before will only work on a TensorFlow model that is built with Keras and stores that Keras model in a “base_model” variable on the model class. If you are using torch, you would need a different method to print the model. If you are using TF but not Keras, which is what is happening in the example you posted, you would also need a different method to print a summary. And if you were using a model that did use Keras but did not assign it to the “base_model” member variable, you would have to adjust for that.

The approach of using after_init to print the summary should work in all those cases but the actual code in the lambda or function you define will have to be model specific.
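
As an illustration (a sketch only, not code from this thread), the after_init function could branch on the model type, falling back to printing the torch module when there is no Keras “base_model”:

    from ray.rllib.agents.ppo import PPOTrainer

    def print_model(trainer):
        model = trainer.get_policy().model
        if hasattr(model, "base_model"):
            # TF models that wrap a Keras model expose it as `base_model`.
            model.base_model.summary()
        else:
            # Torch ModelV2s are nn.Modules, so their repr lists the layers.
            print(model)

    ModelPrintPPOTrainer = PPOTrainer.with_updates(after_init=print_model)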


Wow, that’s very useful. Thank you very much.

 # TODO: (sven): Get rid of `get_initial_state` once Trajectory

Does the “sven” in this code refer to you? :open_mouth:

Thanks @mannyv for your detailed answer and the really cool after_init hack :slight_smile:
This is great. @bug404, I answered your question on how to do this for torch in another thread here in the forum:


Yeah, I have seen it. Thank you very much.