Issues with Tuning Batch Size with Ray AIR and PyTorch Lightning

I am using PyTorch Lightning with Ray AIR training, based on this tutorial.

But I am running into an issue when passing the config to the data module. The suggested config parameters are passed to my Lightning trainer without a problem, but the data module config does not seem to be updated. I can't find any information online about how to do this with Ray AIR…

What would be the best way to tune the batch size when using Lightning data modules?
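For context, the data module reads its batch size from the config dict at construction time, roughly like the simplified sketch below (the placeholder tensors stand in for the real dataset, and I'm assuming batch_size sits at the top level of cfg):

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class BiomassDataModule(pl.LightningDataModule):
    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg

    def setup(self, stage=None):
        # Placeholder data; the real class builds the biomass datasets here.
        self.train_set = TensorDataset(torch.randn(128, 3), torch.randn(128, 1))

    def train_dataloader(self):
        # batch_size is fixed from cfg when the module is created, so a value
        # suggested by Tune never reaches this dataloader unless cfg itself is
        # updated for each trial.
        return DataLoader(self.train_set, batch_size=self.cfg["batch_size"], shuffle=True)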

My Ray code is based on this tutorial: Using PyTorch Lightning with Tune — Ray 2.8.0

from ray.train.lightning import LightningTrainer, LightningConfigBuilder
from ray import air, tune
from ray.train.torch import TorchConfig
from ray.air.config import ScalingConfig
from ray.tune.schedulers import ASHAScheduler
import yaml

from dataset import BiomassDataModule
from trainer import LitModel

def tune_model(target_metric="val_loss", direction="min", n_trials=100):

    logger = None

    with open("config.yaml", "r") as yamlfile:
        cfg = yaml.load(yamlfile, Loader=yaml.FullLoader)

    data_module = BiomassDataModule(cfg=cfg)

    # Static config that does not change across trials
    static_lightning_config = (
        LightningConfigBuilder()
        .module(cls=LitModel, config=cfg['hp'])
        .trainer(max_epochs=cfg['num_epochs'], accelerator='gpu', logger=logger,
                 limit_train_batches=cfg['partition_train'],
                 precision=cfg['precision']
                 )
        .fit_params(datamodule=data_module)
        .checkpointing(monitor=target_metric,
                       save_top_k=1,
                       mode=direction,
                       )
        .build()
    )

    # Searchable config that varies across trials
    searchable_lightning_config = (
        LightningConfigBuilder()
        .module(config={
            "lr": tune.loguniform(1e-6, 1e-2),
            "batch_size": tune.choice([62, 96]),  # tune.choice([8, 16, 32, 64, 128]),
            "dropout_1": tune.uniform(0.0, 0.6),
            "dropout_final": tune.uniform(0.0, 0.6),
        })
        .build()
    )

    scaling_config = ScalingConfig(
        num_workers=1, use_gpu=True, resources_per_worker={"CPU": 1, "GPU": 2}
    )

    # Define a base LightningTrainer without hyperparameters for the Tuner
    lightning_trainer = LightningTrainer(
        lightning_config=static_lightning_config,
        scaling_config=scaling_config,
        torch_config=TorchConfig(backend='gloo')
    )

    scheduler = ASHAScheduler(max_t=cfg['num_epochs'], grace_period=1, reduction_factor=2)

    tuner = tune.Tuner(
        lightning_trainer,
        param_space={"lightning_config": searchable_lightning_config},
        tune_config=tune.TuneConfig(
            metric=target_metric,
            mode=direction,
            num_samples=n_trials,
            scheduler=scheduler,
        ),
        run_config=air.RunConfig(storage_path="./"),
    )

    results = tuner.fit()
    best_result = results.get_best_result(metric=target_metric, mode=direction)
    print(best_result)


if __name__ == "__main__":
    tune_model(n_trials=3)

The new Ray 2.7 update provides a solution to this problem: the data module can be built inside the training function, so the tuned config reaches it directly. This solves my problem, so this issue can be closed.

import os

import yaml
import pytorch_lightning as pl
from ray.train.lightning import (
    RayDDPStrategy,
    RayLightningEnvironment,
    RayTrainReportCallback,
    prepare_trainer,
)

from dataset import BiomassDataModule
from trainer import LitModel


def train_func(tune_cfg):
    current_directory = os.getcwd()
    print("Current ray working directory:", current_directory)

    # Load hps and update with ray tune config
    with open(r'D:/Sync/RQ2/Analysis/config.yaml', "r") as yamlfile:
        cfg = yaml.load(yamlfile, Loader=yaml.FullLoader)

    cfg.update(tune_cfg)

    # Create a Lightning model
    model = LitModel(cfg)

    # Create a Lighting Trainer
    trainer = pl.Trainer(
        max_epochs=cfg['num_epochs'],
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_progress_bar=False,
    )

    # Validate lightning trainer configuration
    trainer = prepare_trainer(trainer)

    # Build your datasets on each worker
    data_module = BiomassDataModule(cfg=cfg)

    # Train model
    trainer.fit(model, datamodule=data_module)
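For completeness, this is roughly how train_func above gets wired into a Tuner with the new Train/Tune API. The search space and scaling values just mirror my earlier setup, so treat it as a sketch rather than a drop-in recipe; everything under train_loop_config is passed to train_func as tune_cfg, which is how the tuned batch_size ends up in cfg and in the data module.

from ray import tune
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler


def tune_model(target_metric="val_loss", direction="min", n_trials=3):
    # One GPU worker per trial; adjust to your hardware.
    ray_trainer = TorchTrainer(
        train_func,
        scaling_config=ScalingConfig(num_workers=1, use_gpu=True),
    )

    tuner = tune.Tuner(
        ray_trainer,
        # These values are merged into cfg via cfg.update(tune_cfg) in train_func.
        param_space={
            "train_loop_config": {
                "lr": tune.loguniform(1e-6, 1e-2),
                "batch_size": tune.choice([8, 16, 32, 64, 128]),
                "dropout_1": tune.uniform(0.0, 0.6),
                "dropout_final": tune.uniform(0.0, 0.6),
            }
        },
        tune_config=tune.TuneConfig(
            metric=target_metric,
            mode=direction,
            num_samples=n_trials,
            scheduler=ASHAScheduler(grace_period=1, reduction_factor=2),
        ),
    )

    results = tuner.fit()
    print(results.get_best_result(metric=target_metric, mode=direction))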