The new Ray 2.7 release provides a solution to this problem: the data module can now be constructed inside the training function. This solves my problem, so this issue can be closed.
def train_func(tune_cfg):
    """Train ``LitModel`` with a Ray-enabled Lightning trainer.

    Base hyperparameters are read from the YAML config file and then
    overridden by the entries of ``tune_cfg``. The data module is built
    inside this function, i.e. in whichever process executes it.

    Args:
        tune_cfg: Mapping of hyperparameter overrides (presumably the Ray
            Tune trial config -- confirm against the caller). Its entries
            take precedence over the values loaded from the YAML file.

    Raises:
        OSError: If the YAML config file cannot be opened.
        KeyError: If ``num_epochs`` is in neither the YAML file nor
            ``tune_cfg``.
    """
    current_directory = os.getcwd()
    print("Current ray working directory:", current_directory)

    # Load hps and update with ray tune config.
    # Explicit UTF-8: without it, open() falls back to the locale encoding
    # (e.g. cp1252 on Windows), which can mis-decode a UTF-8 YAML file.
    with open(r'D:/Sync/RQ2/Analysis/config.yaml', "r", encoding="utf-8") as yamlfile:
        # NOTE(review): FullLoader is acceptable for a trusted local file;
        # switch to yaml.safe_load if this config can come from elsewhere.
        # An empty YAML document parses to None, so fall back to an empty
        # dict to keep the update below from failing with AttributeError.
        cfg = yaml.load(yamlfile, Loader=yaml.FullLoader) or {}
    cfg.update(tune_cfg)

    # Create a Lightning model
    model = LitModel(cfg)

    # Create a Lightning Trainer wired up with the Ray strategy, environment
    # plugin and reporting callback.
    trainer = pl.Trainer(
        max_epochs=cfg['num_epochs'],
        devices="auto",
        accelerator="auto",
        strategy=RayDDPStrategy(),
        plugins=[RayLightningEnvironment()],
        callbacks=[RayTrainReportCallback()],
        enable_progress_bar=False,
    )

    # Validate lightning trainer configuration
    trainer = prepare_trainer(trainer)

    # Build your datasets on each worker
    data_module = BiomassDataModule(cfg=cfg)

    # Train model
    trainer.fit(model, datamodule=data_module)