Ray Tune: method 'wait' of 'ray._raylet.CoreWorker' objects takes a very long time

I am using Hydra to manage configs (the search space is passed as `config`, and the rest of the architecture is passed as `arch_cfg`), hence the use of `instantiate` to initialise objects.

I am also using `wandb_mixin` with the train function. I've used `wandb_mixin` in my plain PyTorch script as well, and there it does not cause the performance issues I am seeing with PyTorch Lightning.

This is my overall train function:

@wandb_mixin
def pl_train_func(config, arch_cfg, checkpoint_dir=None):
    """Ray Tune function trainable that builds and fits one Lightning model.

    The Tune-sampled ``config`` is merged over the Hydra/OmegaConf
    ``arch_cfg``. The merged config is written to ``config.yaml``, logged to
    wandb as an artifact and as flattened run config, and then used to
    instantiate the datamodule, network, model and trainer.

    Optional keys read from the merged config:
        seed: passed to ``seed_everything`` when present (0 is a valid seed).
        checkpoint_name: checkpoint filename (default "model_checkpoint.ckpt").
        enable_tune_cktp: forwarded as ``keep_pth`` to the Tune callback.
        watch_log / watch_log_freq: arguments for ``WandbLogger.watch``
            (defaults "all" / 100).

    Args:
        config: Hyperparameters sampled by Tune for this trial.
        arch_cfg: Base architecture config that ``config`` is merged over.
        checkpoint_dir: Directory passed by Tune when restoring a trial. The
            checkpoint inside it is used to resume training if it exists.
    """
    # Directory the experiment was launched from. Objects are instantiated
    # from here, presumably so relative paths in the config resolve.
    project_dir = os.getenv('TUNE_ORIG_WORKING_DIR')

    # wandb is already initialised by @wandb_mixin, so the Lightning logger
    # only needs to wrap the existing run.
    wandb_logger = WandbLogger()

    config = OmegaConf.create(config)
    arch_cfg = OmegaConf.merge(arch_cfg, config)
    write_conf(arch_cfg, "config.yaml")
    wandb_logger.experiment.log_artifact("config.yaml", type="config")

    # Convert to a plain dict first: popping from a dict needs neither a copy
    # of the DictConfig nor toggling its struct flag.
    run_config = OmegaConf.to_container(arch_cfg)
    run_config["run"].pop("config")
    run_config = flatten_dict(run_config, delimiter='-')
    wandb_logger.experiment.config.update(run_config, allow_val_change=True)

    # Compare against None so that seed=0 is still applied.
    seed = arch_cfg.get("seed")
    if seed is not None:
        seed_everything(seed, workers=True)

    checkpoint_name = arch_cfg.get("checkpoint_name", "model_checkpoint.ckpt")

    with change_directory(project_dir):
        datamodule = instantiate(arch_cfg.datamodule)

        output_size = len(datamodule.categories)
        channel_n = datamodule.channel_n
        net = instantiate(arch_cfg.model.net, output_size=output_size, channel_n=channel_n)

        model = instantiate(arch_cfg.model, net=net)

        # log="all" computes parameter and gradient histograms every
        # `log_freq` steps. With log_freq=1 that happens on every batch, which
        # is a large per-step overhead, so the frequency is configurable.
        wandb_logger.watch(
            model,
            log=arch_cfg.get("watch_log", "all"),
            log_freq=arch_cfg.get("watch_log_freq", 100),
        )

        wandb_logger.experiment.define_metric(name=arch_cfg.monitor, summary=arch_cfg.mode)

        tune_callback = TuneReportCheckpointCallback(filename=checkpoint_name, on="validation_end",
                                                     keep_pth=arch_cfg.get("enable_tune_cktp", True))

        resume_checkpoint = None
        if checkpoint_dir:
            candidate = os.path.join(checkpoint_dir, checkpoint_name)
            # The file can be absent (e.g. when checkpoint saving is disabled
            # via enable_tune_cktp=False). Passing a missing path to fit()
            # would fail, so start from scratch instead.
            # NOTE(review): a checkpoint saved with weights_only=True has no
            # optimizer state, and Lightning will refuse to resume fit() from
            # it. Confirm full checkpoints are saved if resuming is needed.
            if os.path.isfile(candidate):
                resume_checkpoint = candidate

        callbacks = [tune_callback]
        if arch_cfg.pl_trainer.get("callbacks"):
            callbacks = list(instantiate(arch_cfg.pl_trainer.callbacks)) + callbacks

        trainer: Trainer = instantiate(arch_cfg.pl_trainer, callbacks=callbacks, enable_progress_bar=False,
                                       logger=wandb_logger)

    # fit() runs outside change_directory, presumably so outputs land in the
    # trial's working directory rather than the project directory.
    try:
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=resume_checkpoint)
    finally:
        # Remove the wandb hooks even if training raises.
        wandb_logger.experiment.unwatch(model)

On the topic of checkpoints: I've slightly modified Ray's PyTorch Lightning checkpoint callback so that I have the option to turn off checkpointing completely (which I couldn't do otherwise because of this issue). This is now the `_TuneCheckpointCallback` class that `TuneReportCheckpointCallback` uses:

class _TuneCheckpointCallback(TuneCallback):
    """PyTorch Lightning checkpoint callback for Ray Tune.

    Saves a checkpoint into a Tune checkpoint directory whenever the
    configured hook (``on``) fires, which is validation end by default.
    Checkpoints are currently not registered if no ``tune.report()`` call
    is made afterwards. Consider using ``TuneReportCheckpointCallback``
    instead.

    Args:
        filename: Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        on: When to trigger checkpoint creations. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "batch_start", or "train_end". Defaults to "validation_end".
        keep_pth: If False, the callback does nothing: no Tune checkpoint
            directory is created and nothing is saved. Defaults to True.
        weights_only: Forwarded to ``trainer.save_checkpoint``. Weights-only
            checkpoints lack optimizer state, so they cannot be used to
            resume ``trainer.fit``. Defaults to True.
    """

    def __init__(
            self, filename: str = "checkpoint", on="validation_end", keep_pth=True,
            *, weights_only=True):
        super().__init__(on)
        self._filename = filename
        self.keep_pth = keep_pth
        self.weights_only = weights_only

    def _handle(self, trainer: Trainer, pl_module: LightningModule):
        # Ignore the sanity-check validation pass. When checkpointing is
        # disabled, return before entering tune.checkpoint_dir: entering it
        # would still create and register an (empty) checkpoint directory.
        if trainer.sanity_checking or not self.keep_pth:
            return
        step = f"epoch-{trainer.current_epoch}-step-{trainer.global_step}"
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename),
                                    weights_only=self.weights_only)

I've had to implement something similar in the plain PyTorch script as well.

I save all checkpoints and logs to the local disk, not to cloud storage.