Issue with LightningConfigBuilder

Hello,

I am following this tutorial to implement ray air for pytorch lightning tuning.

My code is pasted below. The code fails during:

LightningConfigBuilder()
.module(cls=BiomassModel, **cfg)

With the error message:

“ValueError: ‘module_class’ must be a subclass of ‘pl.LightningModule’!”

However, as can be seen in my code, I check
issubclass(BiomassModel, pl.LightningModule)

Before static_lightning_config and this returns True. It appears that for some reason once BiomassModel is within the .module method it no longer is a subclass of pl.LightningModule.

Any idea what is happening here?

from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray import air, tune
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler
import yaml
import lightning.pytorch as pl
import torch
from torchmetrics import R2Score
from torchmetrics.functional import r2_score, mean_squared_error

# My code
from data import PointCloudDataModule
from utils.get_model import get_model
from utils.training_utils import loss_fn
from utils.data_utils import re_convert_to_Mg_ha, update_z_score_conversion_info

logger = None

num_epochs = 3

with open("config.yaml", "r") as yamlfile:
    cfg = yaml.load(yamlfile, Loader=yaml.FullLoader)

data_module = PointCloudDataModule(cfg=cfg)

target_metric = "val_loss"

direction = "min"

n_trials = 3

class BiomassModel(pl.LightningModule):
    def __init__(self, cfg):
        super(BiomassModel, self).__init__()

        #Attach all arguments to lightning module
        self.cfg = cfg

        # Update mean and sd and save for z-score conversion
        self.z_info = update_z_score_conversion_info(cfg['data_dir'])

        # Set up model
        self.model = get_model(cfg)

        # Set up evaluation metrics with torchmetrics
        self.train_r2 = R2Score(num_outputs=4, adjusted=0, multioutput='uniform_average', dist_sync_on_step=False)
        self.val_r2 = R2Score(num_outputs=4, adjusted=0, multioutput='uniform_average', dist_sync_on_step=False)
        self.test_r2 = R2Score(num_outputs=4, adjusted=0, multioutput='uniform_average', dist_sync_on_step=False)

    def training_step(self, batch, batch_idx):
        # Forward pass
        pred = self.model(batch)

        # Calculate loss
        train_loss = loss_fn(pred=pred, y=batch['target'], loss_function_type=self.cfg['loss_function_type'])

        # Convert pred and y from z-score to Mg/ha value to compute R^2
        pred = re_convert_to_Mg_ha(self.z_info, z_components_arr=pred)
        y = re_convert_to_Mg_ha(self.z_info, z_components_arr=batch['target'])
        train_r2 = self.train_r2(pred, y)

        # Log metrics
        self.log("train_loss", value=train_loss, batch_size=self.cfg['batch_size'],
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("train_r2", value=train_r2, batch_size=self.cfg['batch_size'],
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        # Forward pass
        pred = self.model(batch)

        # Calculate loss
        val_loss = loss_fn(pred=pred, y=batch['target'], loss_function_type=self.cfg['loss_function_type'])

        # Convert pred and y from z-score to Mg/ha value to compute R^2
        pred = re_convert_to_Mg_ha(self.z_info, z_components_arr=pred)
        y = re_convert_to_Mg_ha(self.z_info, z_components_arr=batch['target'])
        val_r2 = self.val_r2(pred, y)

        # Log metrics
        self.log("val_loss", value=val_loss, batch_size=self.cfg['batch_size'],
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val_r2", value=val_r2, batch_size=self.cfg['batch_size'],
                 on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)

        return val_loss

    def test_step(self, batch, batch_idx):
        # Forward pass
        pred = self.model(batch)

        # Calculate loss
        test_loss = loss_fn(pred=pred, y=batch['target'], loss_function_type=self.cfg['loss_function_type'])

        # Convert pred and y from z-score to Mg/ha value to compute R^2
        pred = re_convert_to_Mg_ha(self.z_info, z_components_arr=pred)
        y = re_convert_to_Mg_ha(self.z_info, z_components_arr=batch['target'])

        # Get loss of total AGB
        tree_pred = pred[:, 0] + pred[:, 1] + pred[:, 2] + pred[:, 3]
        tree_obs = y[:, 0] + y[:, 1] + y[:, 2] + y[:, 3]

        # Calculate metrics for component
        comp_list = ['bark', 'branch', 'foliage', 'wood']
        idx_list = [0, 1, 2, 3]
        test_metric_dict = dict()
        for comp, idx in zip(comp_list, idx_list):
            test_metric_dict[comp + "_r2"] = r2_score(preds=pred[:, idx], target=y[:, idx])
            test_metric_dict[comp + "_rmse"] = torch.sqrt(mean_squared_error(preds=pred[:, idx], target=y[:, idx]))

        # Calculate metrics for tree and overall
        test_metric_dict['tree_r2'] = r2_score(preds=tree_pred, target=tree_obs)
        test_metric_dict['overall_r2'] = r2_score(preds=pred, target=y, adjusted=0, multioutput='uniform_average')
        test_metric_dict['tree_mse'] = torch.sqrt(mean_squared_error(preds=tree_pred, target=tree_obs))
        test_metric_dict['overall_mse'] = torch.sqrt(mean_squared_error(preds=pred, target=y))

        # Log metrics
        self.log_dict(test_metric_dict, batch_size=self.cfg['batch_size'])

        return test_loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=self.cfg['lr'],
                                      weight_decay=self.cfg['weight_decay'])
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                                            T_0=self.cfg['cawr_t_0'],
                                                                            T_mult=self.cfg['cawr_t_mult']
                                                                            )
        return [optimizer], [lr_scheduler]


issubclass(BiomassModel, pl.LightningModule)

# Static configs that does not change across trials
static_lightning_config = (
    LightningConfigBuilder()
    .module(cls=BiomassModel, **cfg)
    .trainer(max_epochs=num_epochs, accelerator='gpu', logger=logger)
    .fit_params(datamodule=data_module)
    .checkpointing(monitor=target_metric, 
                   save_top_k=0, #No models are saved during HP tuning
                   mode=direction)
    .build()
)

# Searchable configs across different trials
searchable_lightning_config = (
    LightningConfigBuilder()
    .module(config={
        "lr": tune.loguniform(1e-4, 1e-1),
    })
    .build()
)

# Make sure to also define an AIR CheckpointConfig here
# to properly save checkpoints in AIR format.
run_config = RunConfig(
    checkpoint_config=CheckpointConfig(
        num_to_keep=0,
        checkpoint_score_attribute=target_metric,
        checkpoint_score_order=direction,
    ),
)

scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

scaling_config = ScalingConfig(
    num_workers=3, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 2}
)

# Define a base LightningTrainer without hyper-parameters for Tuner
lightning_trainer = LightningTrainer(
    lightning_config=static_lightning_config,
    scaling_config=scaling_config,
    run_config=run_config,
)


scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)

tuner = tune.Tuner(
    lightning_trainer,
    param_space={"lightning_config": searchable_lightning_config},
    tune_config=tune.TuneConfig(
        metric=target_metric,
        mode=direction,
        num_samples=n_trials,
        scheduler=scheduler,
    ),
    run_config=air.RunConfig(
        name="ray_tune_test",
    ),
)
results = tuner.fit()
best_result = results.get_best_result(metric=target_metric, mode=direction)
print(best_result)

Hey @Harry_Seely , does this issue happen if you run this without Tune, if you just run lightning_trainer.fit()?

Hey @Harry_Seely Can you try to use import pytorch_lightning as pl. We haven’t update the import path so that may raise an error.

1 Like

thanks @matthewdeng and @yunxuanx for jumping in.

Thank you, this solved the problem!