I am using PyTorch Lightning and Ray AIR training based on this tutorial.
But I am running into the issue of passing the config to the data module. The suggested config parameters are passed on to my Lightning trainer no problem, but the data module config does not seem to be updated. I can't seem to find any information online about how to do this with Ray AIR.
What would be the best way to tune batch size when using lightning data modules?
Ray code based on this tutorial: Using PyTorch Lightning with Tune — Ray 2.8.0
from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray import air, tune
from ray.train.torch import TorchConfig
from ray.air.config import ScalingConfig
from ray.tune.schedulers import ASHAScheduler
import yaml

from dataset import BiomassDataModule
from trainer import LitModel


def tune_model(target_metric="val_loss", direction="min", n_trials=100):
    """Run a Ray Tune hyper-parameter search over a Ray AIR LightningTrainer.

    Reads the project settings from ``config.yaml``, builds a static
    Lightning config (model class, trainer flags, datamodule, checkpointing)
    plus a searchable config of hyper-parameters, and launches an ASHA-scheduled
    search.

    Args:
        target_metric: Metric name reported by the LightningModule to optimize.
        direction: "min" or "max" — optimization mode for ``target_metric``.
        n_trials: Number of Tune samples to draw.

    Side effects:
        Reads ``config.yaml`` from the working directory, writes Tune results
        under ``./`` (``RunConfig.storage_path``), and prints the best result.
    """
    logger = None
    with open("config.yaml", "r") as yamlfile:
        # NOTE(review): FullLoader can construct arbitrary Python objects;
        # yaml.safe_load is preferable if the config is ever untrusted.
        cfg = yaml.load(yamlfile, Loader=yaml.FullLoader)

    data_module = BiomassDataModule(cfg=cfg)

    # Static configs that do not change across trials.
    static_lightning_config = (
        LightningConfigBuilder()
        .module(cls=LitModel, config=cfg['hp'])
        .trainer(
            max_epochs=cfg['num_epochs'],
            accelerator='gpu',
            logger=logger,
            limit_train_batches=cfg['partition_train'],
            precision=cfg['precision'],
        )
        .fit_params(datamodule=data_module)
        .checkpointing(monitor=target_metric, save_top_k=1, mode=direction)
        .build()
    )

    # Searchable configs across different trials.
    # NOTE(review): .module(config=...) only reaches the LightningModule's
    # config — the datamodule above is constructed once, outside the trial,
    # so "batch_size" sampled here will NOT update BiomassDataModule. To tune
    # datamodule parameters, the datamodule must be built per-trial from the
    # sampled config (e.g. inside the training function) — TODO confirm
    # against the Ray AIR docs for the installed version.
    searchable_lightning_config = (
        LightningConfigBuilder()
        .module(config={
            "lr": tune.loguniform(1e-6, 1e-2),
            "batch_size": tune.choice([62, 96]),  # tune.choice([8, 16, 32, 64, 128]),
            "dropout_1": tune.uniform(0.0, 0.6),
            "dropout_final": tune.uniform(0.0, 0.6),
        })
        .build()
    )

    scaling_config = ScalingConfig(
        num_workers=1,
        use_gpu=True,
        resources_per_worker={"CPU": 1, "GPU": 2},
    )

    # Define a base LightningTrainer without hyper-parameters for the Tuner.
    lightning_trainer = LightningTrainer(
        lightning_config=static_lightning_config,
        scaling_config=scaling_config,
        torch_config=TorchConfig(backend='gloo'),
    )

    # ASHA early-stops underperforming trials; max_t caps epochs per trial.
    scheduler = ASHAScheduler(
        max_t=cfg['num_epochs'], grace_period=1, reduction_factor=2
    )

    tuner = tune.Tuner(
        lightning_trainer,
        param_space={"lightning_config": searchable_lightning_config},
        tune_config=tune.TuneConfig(
            metric=target_metric,
            mode=direction,
            num_samples=n_trials,
            scheduler=scheduler,
        ),
        run_config=air.RunConfig(storage_path="./"),
    )

    results = tuner.fit()
    best_result = results.get_best_result(metric=target_metric, mode=direction)
    print(best_result)
if __name__ == "__main__":
    # Small trial budget for a quick smoke-test run.
    tune_model(n_trials=3)