Hello,
I am following this tutorial to implement ray air for pytorch lightning tuning.
My code is pasted below. The code fails during:
LightningConfigBuilder()
.module(cls=BiomassModel, **cfg)
With the error message:
“ValueError: ‘module_class’ must be a subclass of ‘pl.LightningModule’!”
However, as can be seen in my code, I check
issubclass(BiomassModel, pl.LightningModule)
Before static_lightning_config and this returns True. It appears that for some reason once BiomassModel is within the .module method it no longer is a subclass of pl.LightningModule.
Any idea what is happening here?
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray import air, tune
from ray.air.config import RunConfig, ScalingConfig, CheckpointConfig
from ray.tune.schedulers import ASHAScheduler
import yaml
import lightning.pytorch as pl
import torch
from torchmetrics import R2Score
from torchmetrics.functional import r2_score, mean_squared_error
# My code
from data import PointCloudDataModule
from utils.get_model import get_model
from utils.training_utils import loss_fn
from utils.data_utils import re_convert_to_Mg_ha, update_z_score_conversion_info
logger = None
num_epochs = 3
with open("config.yaml", "r") as yamlfile:
cfg = yaml.load(yamlfile, Loader=yaml.FullLoader)
data_module = PointCloudDataModule(cfg=cfg)
target_metric = "val_loss"
direction = "min"
n_trials = 3
class BiomassModel(pl.LightningModule):
def __init__(self, cfg):
super(BiomassModel, self).__init__()
#Attach all arguments to lightning module
self.cfg = cfg
# Update mean and sd and save for z-score conversion
self.z_info = update_z_score_conversion_info(cfg['data_dir'])
# Set up model
self.model = get_model(cfg)
# Set up evaluation metrics with torchmetrics
self.train_r2 = R2Score(num_outputs=4, adjusted=0, multioutput='uniform_average', dist_sync_on_step=False)
self.val_r2 = R2Score(num_outputs=4, adjusted=0, multioutput='uniform_average', dist_sync_on_step=False)
self.test_r2 = R2Score(num_outputs=4, adjusted=0, multioutput='uniform_average', dist_sync_on_step=False)
def training_step(self, batch, batch_idx):
# Forward pass
pred = self.model(batch)
# Calculate loss
train_loss = loss_fn(pred=pred, y=batch['target'], loss_function_type=self.cfg['loss_function_type'])
# Convert pred and y from z-score to Mg/ha value to compute R^2
pred = re_convert_to_Mg_ha(self.z_info, z_components_arr=pred)
y = re_convert_to_Mg_ha(self.z_info, z_components_arr=batch['target'])
train_r2 = self.train_r2(pred, y)
# Log metrics
self.log("train_loss", value=train_loss, batch_size=self.cfg['batch_size'],
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.log("train_r2", value=train_r2, batch_size=self.cfg['batch_size'],
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return train_loss
def validation_step(self, batch, batch_idx):
# Forward pass
pred = self.model(batch)
# Calculate loss
val_loss = loss_fn(pred=pred, y=batch['target'], loss_function_type=self.cfg['loss_function_type'])
# Convert pred and y from z-score to Mg/ha value to compute R^2
pred = re_convert_to_Mg_ha(self.z_info, z_components_arr=pred)
y = re_convert_to_Mg_ha(self.z_info, z_components_arr=batch['target'])
val_r2 = self.val_r2(pred, y)
# Log metrics
self.log("val_loss", value=val_loss, batch_size=self.cfg['batch_size'],
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
self.log("val_r2", value=val_r2, batch_size=self.cfg['batch_size'],
on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
return val_loss
def test_step(self, batch, batch_idx):
# Forward pass
pred = self.model(batch)
# Calculate loss
test_loss = loss_fn(pred=pred, y=batch['target'], loss_function_type=self.cfg['loss_function_type'])
# Convert pred and y from z-score to Mg/ha value to compute R^2
pred = re_convert_to_Mg_ha(self.z_info, z_components_arr=pred)
y = re_convert_to_Mg_ha(self.z_info, z_components_arr=batch['target'])
# Get loss of total AGB
tree_pred = pred[:, 0] + pred[:, 1] + pred[:, 2] + pred[:, 3]
tree_obs = y[:, 0] + y[:, 1] + y[:, 2] + y[:, 3]
# Calculate metrics for component
comp_list = ['bark', 'branch', 'foliage', 'wood']
idx_list = [0, 1, 2, 3]
test_metric_dict = dict()
for comp, idx in zip(comp_list, idx_list):
test_metric_dict[comp + "_r2"] = r2_score(preds=pred[:, idx], target=y[:, idx])
test_metric_dict[comp + "_rmse"] = torch.sqrt(mean_squared_error(preds=pred[:, idx], target=y[:, idx]))
# Calculate metrics for tree and overall
test_metric_dict['tree_r2'] = r2_score(preds=tree_pred, target=tree_obs)
test_metric_dict['overall_r2'] = r2_score(preds=pred, target=y, adjusted=0, multioutput='uniform_average')
test_metric_dict['tree_mse'] = torch.sqrt(mean_squared_error(preds=tree_pred, target=tree_obs))
test_metric_dict['overall_mse'] = torch.sqrt(mean_squared_error(preds=pred, target=y))
# Log metrics
self.log_dict(test_metric_dict, batch_size=self.cfg['batch_size'])
return test_loss
def configure_optimizers(self):
optimizer = torch.optim.AdamW(params=self.model.parameters(), lr=self.cfg['lr'],
weight_decay=self.cfg['weight_decay'])
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
T_0=self.cfg['cawr_t_0'],
T_mult=self.cfg['cawr_t_mult']
)
return [optimizer], [lr_scheduler]
issubclass(BiomassModel, pl.LightningModule)
# Static configs that does not change across trials
static_lightning_config = (
LightningConfigBuilder()
.module(cls=BiomassModel, **cfg)
.trainer(max_epochs=num_epochs, accelerator='gpu', logger=logger)
.fit_params(datamodule=data_module)
.checkpointing(monitor=target_metric,
save_top_k=0, #No models are saved during HP tuning
mode=direction)
.build()
)
# Searchable configs across different trials
searchable_lightning_config = (
LightningConfigBuilder()
.module(config={
"lr": tune.loguniform(1e-4, 1e-1),
})
.build()
)
# Make sure to also define an AIR CheckpointConfig here
# to properly save checkpoints in AIR format.
run_config = RunConfig(
checkpoint_config=CheckpointConfig(
num_to_keep=0,
checkpoint_score_attribute=target_metric,
checkpoint_score_order=direction,
),
)
scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
scaling_config = ScalingConfig(
num_workers=3, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 2}
)
# Define a base LightningTrainer without hyper-parameters for Tuner
lightning_trainer = LightningTrainer(
lightning_config=static_lightning_config,
scaling_config=scaling_config,
run_config=run_config,
)
scheduler = ASHAScheduler(max_t=num_epochs, grace_period=1, reduction_factor=2)
tuner = tune.Tuner(
lightning_trainer,
param_space={"lightning_config": searchable_lightning_config},
tune_config=tune.TuneConfig(
metric=target_metric,
mode=direction,
num_samples=n_trials,
scheduler=scheduler,
),
run_config=air.RunConfig(
name="ray_tune_test",
),
)
results = tuner.fit()
best_result = results.get_best_result(metric=target_metric, mode=direction)
print(best_result)