How to properly restore a checkpoint when using PyTorch Lightning?

I have a Ray Tune analysis object, and I can get the best checkpoint from it:

```python
analysis = tune_robert_asha(num_samples=2)
best_ckpt = analysis.best_checkpoint
```

But I am unable to restore my PyTorch Lightning model from it.

I tried:

```python
import os

MyLightningModel.load_from_checkpoint(
    os.path.join(analysis.best_checkpoint, "checkpoint")
)
```

However, MyLightningModel takes a config dict in its constructor so that Ray Tune can set certain hyperparameters for each trial:

```python
import pytorch_lightning as pl
import torch.nn as nn


class MyLightningModel(pl.LightningModule):
    def __init__(self, config=None):
        # Set up the LightningModule / nn.Module internals first
        super().__init__()
        # Hyperparameters chosen by Ray Tune for this trial
        self.lr = config["lr"]
        self.batch_size = config["batch_size"]
        self.layer_size = config["layer_size"]

        self.lstm = nn.LSTM(768, self.layer_size, num_layers=1, bidirectional=False)
        self.out = nn.Linear(self.layer_size, 1)
```

Thus, when I call load_from_checkpoint, the constructor of MyLightningModel fails because config is None:

```
TypeError                                 Traceback (most recent call last)
in <module>()
      1 MyLightningModel.load_from_checkpoint(
----> 2     os.path.join(analysis.best_checkpoint, "checkpoint")
      3 )

2 frames
in __init__(self, config)
      3     def __init__(self, config=None):
      4 
----> 5         self.lr = config["lr"]
      6         self.batch_size = config["batch_size"]
      7         self.layer_size = config["layer_size"]

TypeError: 'NoneType' object is not subscriptable
```

How is this supposed to be handled?


@Luca_Guarro you can also pass config=config when you call load_from_checkpoint; extra keyword arguments given to load_from_checkpoint are forwarded to the LightningModule constructor.
