I am using Hydra to manage configs (the search space is passed as config, and the rest of the architecture as arch_cfg), hence the use of instantiate to initialise objects.
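For context, this is roughly how the two configs interact via instantiate (the _target_ path and the values below are just placeholders for illustration, not my actual config):

from hydra.utils import instantiate
from omegaconf import OmegaConf

# arch_cfg: the static architecture config; config: the Tune search space for one trial.
# The _target_ below points at a real torch class purely for illustration.
arch_cfg = OmegaConf.create({
    "model": {"net": {"_target_": "torch.nn.Linear", "in_features": 16, "out_features": 4}},
    "lr": 1e-3,
})
config = {"lr": 3e-4}  # values sampled by Ray Tune for this trial

merged = OmegaConf.merge(arch_cfg, OmegaConf.create(config))
net = instantiate(merged.model.net)  # -> torch.nn.Linear(in_features=16, out_features=4)
print(net, merged.lr)                # the trial value 3e-4 overrides the default 1e-3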
I am also using wandb_mixin with the train function. I've used wandb_mixin in my regular PyTorch script and it doesn't have the performance issues I see with PyTorch Lightning.
This is my overall train function:
@wandb_mixin
def pl_train_func(config, arch_cfg, checkpoint_dir=None):
    project_dir = os.getenv('TUNE_ORIG_WORKING_DIR')
    # wandb is already initialised by wandb_mixin in Ray, so we just need to pass the wrapper into PL
    wandb_logger = WandbLogger()

    config = OmegaConf.create(config)
    arch_cfg = OmegaConf.merge(arch_cfg, config)
    write_conf(arch_cfg, "config.yaml")
    wandb_logger.experiment.log_artifact("config.yaml", type="config")

    _config = arch_cfg.copy()
    with open_struct(_config):
        _config.run.pop("config")
    _config = OmegaConf.to_container(_config)
    _config = flatten_dict(_config, delimiter='-')
    wandb_logger.experiment.config.update(_config, allow_val_change=True)

    if arch_cfg.get("seed"):
        seed_everything(arch_cfg.seed, workers=True)

    checkpoint_name = arch_cfg.get("checkpoint_name", "model_checkpoint.ckpt")
    with change_directory(project_dir):
        datamodule = instantiate(arch_cfg.datamodule)
        output_size = len(datamodule.categories)
        channel_n = datamodule.channel_n
        net = instantiate(arch_cfg.model.net, output_size=output_size, channel_n=channel_n)
        model = instantiate(arch_cfg.model, net=net)

        wandb_logger.watch(model, log="all", log_freq=1)
        wandb_logger.experiment.define_metric(name=arch_cfg.monitor, summary=arch_cfg.mode)

        tune_callback = TuneReportCheckpointCallback(filename=checkpoint_name, on="validation_end",
                                                     keep_pth=arch_cfg.get("enable_tune_cktp", True))
        resume_checkpoint = None
        if checkpoint_dir:
            resume_checkpoint = os.path.join(checkpoint_dir, checkpoint_name)

        callbacks = [tune_callback]
        if arch_cfg.pl_trainer.get("callbacks"):
            callbacks = list(instantiate(arch_cfg.pl_trainer.callbacks)) + callbacks

        trainer: Trainer = instantiate(arch_cfg.pl_trainer, callbacks=callbacks, enable_progress_bar=False,
                                       logger=wandb_logger)
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=resume_checkpoint)
        wandb_logger.experiment.unwatch(model)
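For reference, the trainable is launched along these lines (project name, search space, metric and resources below are placeholders; the wandb block is what wandb_mixin reads its settings from):

from ray import tune

analysis = tune.run(
    tune.with_parameters(pl_train_func, arch_cfg=arch_cfg),
    config={
        # per-trial search space, merged into arch_cfg inside pl_train_func
        "lr": tune.loguniform(1e-4, 1e-1),
        # wandb_mixin picks up its settings from this key
        "wandb": {"project": "my-project", "api_key_file": "~/.wandb_api_key"},
    },
    metric="val_loss",  # placeholder, mirrors arch_cfg.monitor
    mode="min",         # placeholder, mirrors arch_cfg.mode
    num_samples=10,
    resources_per_trial={"cpu": 4, "gpu": 1},
)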
On the topic of checkpointing, I've actually changed the Ray PyTorch Lightning TuneReportCheckpointCallback a little so that I have the option to turn off checkpointing completely (I couldn't do so because of this issue). This is now the _TuneCheckpointCallback class that TuneReportCheckpointCallback uses:
import os

from pytorch_lightning import LightningModule, Trainer
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneCallback


class _TuneCheckpointCallback(TuneCallback):
    """PyTorch Lightning checkpoint callback.

    Saves checkpoints after each validation step.

    Checkpoints are currently not registered if no ``tune.report()`` call
    is made afterwards. Consider using ``TuneReportCheckpointCallback``
    instead.

    Args:
        filename: Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        on: When to trigger checkpoint creations. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "batch_start", or "train_end". Defaults to "validation_end".
        keep_pth: If False, skip writing the checkpoint file entirely,
            which effectively disables checkpointing. Defaults to True.
    """

    def __init__(self, filename: str = "checkpoint", on="validation_end", keep_pth=True):
        super(_TuneCheckpointCallback, self).__init__(on)
        self._filename = filename
        self.keep_pth = keep_pth

    def _handle(self, trainer: Trainer, pl_module: LightningModule):
        if trainer.sanity_checking:
            return
        step = f"epoch-{trainer.current_epoch}-step-{trainer.global_step}"
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            if self.keep_pth:
                trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename), weights_only=True)
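The wrapping TuneReportCheckpointCallback then just forwards the new keep_pth flag down to it; a minimal sketch, following the structure of Ray's stock implementation:

from pytorch_lightning import LightningModule, Trainer
from ray.tune.integration.pytorch_lightning import TuneCallback, TuneReportCallback


class TuneReportCheckpointCallback(TuneCallback):
    """Reports metrics and (optionally) saves a checkpoint after validation."""

    def __init__(self, metrics=None, filename="checkpoint", on="validation_end", keep_pth=True):
        super(TuneReportCheckpointCallback, self).__init__(on)
        # forward keep_pth so checkpoint writing can be disabled per trial
        self._checkpoint = _TuneCheckpointCallback(filename, on, keep_pth=keep_pth)
        self._report = TuneReportCallback(metrics, on)

    def _handle(self, trainer: Trainer, pl_module: LightningModule):
        self._checkpoint._handle(trainer, pl_module)
        self._report._handle(trainer, pl_module)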
I've had to implement something similar in the regular PyTorch script as well. I save all checkpoints/logs to the local disk, not to a cloud service.