Issues with tuning batch size using Ray AIR and PyTorch Lightning

The new Ray 2.7 update provides a solution to this problem: the data module can now be built inside the training function. This solves my problem, so this issue can be closed.

def train_func(tune_cfg, config_path=r'D:/Sync/RQ2/Analysis/config.yaml'):
    """Per-worker Ray Train entry point for tuning a Lightning model.

    Loads the base YAML config, overlays the Ray Tune trial config on top of
    it, then builds the model, trainer, and data module on each worker and
    runs training.

    Args:
        tune_cfg: Hyperparameter dict supplied by Ray Tune for this trial;
            its keys override the values loaded from the YAML file.
        config_path: Path to the base YAML config file. Defaults to the
            original hard-coded location so existing callers are unaffected.
    """
    current_directory = os.getcwd()
    print("Current ray working directory:", current_directory)

    # Load base hyperparameters and overlay the Ray Tune trial config.
    # safe_load avoids arbitrary object construction from the YAML file;
    # explicit encoding avoids platform-dependent defaults.
    with open(config_path, "r", encoding="utf-8") as yamlfile:
        cfg = yaml.safe_load(yamlfile)

    cfg.update(tune_cfg)

    # Create a Lightning model
    model = LitModel(cfg)

    # Create a Lightning Trainer wired for Ray: the Ray DDP strategy,
    # environment plugin, and report callback integrate with Ray Train.
    # The progress bar is disabled because per-worker bars clutter Ray logs.
    trainer = pl.Trainer(
        max_epochs=cfg['num_epochs'],
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_progress_bar=False,
    )

    # Validate the Lightning trainer configuration for Ray Train.
    trainer = prepare_trainer(trainer)

    # Build the datasets on each worker (possible since Ray 2.7:
    # the data module can be constructed inside the training function).
    data_module = BiomassDataModule(cfg=cfg)

    # Train model
    trainer.fit(model, datamodule=data_module)
1 Like