How to properly restore a checkpoint when using PyTorch Lightning?

I have a ray tune analysis object and I am able to get the best checkpoint from it:

analysis = tune_robert_asha(num_samples=2)
best_ckpt = analysis.best_checkpoint

But I am unable to restore my PyTorch Lightning model with it.

I try:

MyLightningModel.load_from_checkpoint(
    os.path.join(analysis.best_checkpoint, "checkpoint")
)

But MyLightningModel takes a config argument in its constructor so that Ray Tune can set certain hyperparameters for each trial:

import pytorch_lightning as pl
from torch import nn

class MyLightningModel(pl.LightningModule):
    def __init__(self, config=None):
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.layer_size = config["layer_size"]

        super(MyLightningModel, self).__init__()
        self.lstm = nn.LSTM(768, self.layer_size, num_layers=1, bidirectional=False)
        self.out = nn.Linear(self.layer_size, 1)

Thus when I try to run load_from_checkpoint, I get an error in the constructor of MyLightningModel, since config defaults to None:


TypeError                                 Traceback (most recent call last)
in <module>()
      1 MyLightningModel.load_from_checkpoint(
----> 2     os.path.join(analysis.best_checkpoint, "checkpoint")
      3 )

2 frames
in __init__(self, config)
      3     def __init__(self, config=None):
      4
----> 5         self.lr = config["lr"]
      6         self.batch_size = config["batch_size"]
      7         self.layer_size = config["layer_size"]

TypeError: 'NoneType' object is not subscriptable

How is this supposed to be handled?


@Luca_Guarro you can also pass in config=config when you call load_from_checkpoint, and it will get propagated to the LightningModule constructor.
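For example, something like this should work (a minimal sketch; the checkpoint path layout and the config keys are taken from the snippets above, and extra keyword arguments to load_from_checkpoint are forwarded to MyLightningModel.__init__):

config = {"lr": 1e-3, "batch_size": 32, "layer_size": 128}  # hyperparameters the restored model should use

model = MyLightningModel.load_from_checkpoint(
    os.path.join(analysis.best_checkpoint, "checkpoint"),
    config=config,  # extra kwargs are passed through to __init__
)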


@amogkam could you tell which params should be passed in the config: 1. the best found, or 2. could it be anything, with the best model checkpoint loading them automatically?

You should probably pass in the best found, which you can get from analysis.get_best_config().
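Putting the two answers together, a sketch of the full restore (the metric and mode arguments here are assumptions and must match whatever you passed to tune.run; they can be omitted if they were already set there):

best_config = analysis.get_best_config(metric="val_loss", mode="min")  # assumed metric/mode

model = MyLightningModel.load_from_checkpoint(
    os.path.join(analysis.best_checkpoint, "checkpoint"),
    config=best_config,
)
model.eval()  # ready for evaluation / inference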
