Ray Tune: method 'wait' of 'ray._raylet.CoreWorker' objects takes a very long time

Hi, I am running Ray Tune 1.13.0 with the PyTorch Lightning (1.6.4) integration and the wandb mixin, set up as per the Ray Tune documentation. I am running this on Windows 10 with 2 GPUs (an RTX 2080 Ti and a Quadro P1000) and a Xeon E5-2630 v4 CPU with 64 GB of RAM.

I am experiencing incredibly long run times with this setup compared to my previous setup with Ray Tune and plain PyTorch: there is at least a 5- to 6-fold difference between them (4 hours compared to half an hour).

Profiling the run with cProfile gave the following:

I do not experience similarly long run times on another Windows 11 machine with a single 1050 Ti GPU, nor when I run on Linux.

Hence I was wondering if anyone could offer help or a direction to go in to solve this.

This is part of what the call graph from cProfile looks like:

Hi @stephano41, do you mind sharing more details to reproduce the issue, like code snippets?

My rough read on this issue is that ray.get() is taking the majority of the time during your run, going through the Ray object store. If the trial/model checkpoints being synced are large, that could be the primary bottleneck.
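One rough way to sanity-check this on the affected machine is to time ray.put/ray.get on an object roughly the size of a checkpoint. A minimal sketch (the ~400 MB size is just an assumption standing in for a large checkpoint):

import time

import numpy as np
import ray

ray.init()
# ~400 MB array standing in for a large model checkpoint
blob = np.random.rand(50_000_000)
ref = ray.put(blob)
start = time.perf_counter()
ray.get(ref)
print(f"ray.get of ~400 MB took {time.perf_counter() - start:.3f}s")

If this is already slow on the Windows 10 machine but not on the others, the object store transfer itself, rather than the Lightning integration, is the more likely culprit.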

cc: @kai @xwjiang2010

I am using Hydra to manage configs (the search space is passed as config, and the rest of the architecture is passed as arch_cfg), hence the use of instantiate to initialise objects.
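The instantiate pattern I rely on is just the standard Hydra one; a toy example for reference (datetime.timedelta is only for illustration):

from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "datetime.timedelta", "hours": 2})
print(instantiate(cfg))  # calls datetime.timedelta(hours=2) -> 2:00:00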

I am also using wandb_mixin with the train function. I've had wandb_mixin in my regular PyTorch script as well, and there it has none of the performance issues I see with PyTorch Lightning.
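For context, the bare pattern is roughly the following (a simplified sketch, not my actual code; the project name is a placeholder):

import wandb
from ray import tune
from ray.tune.integration.wandb import wandb_mixin

@wandb_mixin
def minimal_train_fn(config):
    # wandb.init() has already been called by the mixin from config["wandb"],
    # so metrics can be logged straight away.
    wandb.log({"dummy_metric": config["lr"]})
    tune.report(dummy_metric=config["lr"])

tune.run(minimal_train_fn,
         config={"lr": 0.01, "wandb": {"project": "<project_name>"}})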

This is my overall train function:

import os

from hydra.utils import instantiate
from omegaconf import OmegaConf
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from ray.tune.integration.wandb import wandb_mixin

# open_struct, change_directory and write_conf are my own utilities (shown at the
# end of this post); flatten_dict is a small dict-flattening helper, and
# TuneReportCheckpointCallback is my modified version described further down.


@wandb_mixin
def pl_train_func(config, arch_cfg, checkpoint_dir=None):
    project_dir = os.getenv('TUNE_ORIG_WORKING_DIR')

    # wandb is already initialised with wandb_mixin in ray, hence just need to pass in the wrapper into pl
    wandb_logger = WandbLogger()

    config = OmegaConf.create(config)
    arch_cfg = OmegaConf.merge(arch_cfg, config)
    write_conf(arch_cfg, "config.yaml")

    wandb_logger.experiment.log_artifact("config.yaml", type="config")

    _config = arch_cfg.copy()

    with open_struct(_config):
        _config.run.pop("config")
        _config = OmegaConf.to_container(_config)
        _config = flatten_dict(_config, delimiter='-')

    wandb_logger.experiment.config.update(_config, allow_val_change=True)

    if arch_cfg.get("seed"):
        seed_everything(arch_cfg.seed, workers=True)

    checkpoint_name = arch_cfg.get("checkpoint_name", "model_checkpoint.ckpt")

    with change_directory(project_dir):

        datamodule = instantiate(arch_cfg.datamodule)

        output_size = len(datamodule.categories)
        channel_n = datamodule.channel_n
        net = instantiate(arch_cfg.model.net, output_size=output_size, channel_n=channel_n)

        model = instantiate(arch_cfg.model, net=net)

        wandb_logger.watch(model, log="all", log_freq=1)

        wandb_logger.experiment.define_metric(name=arch_cfg.monitor, summary=arch_cfg.mode)

        tune_callback = TuneReportCheckpointCallback(filename=checkpoint_name, on="validation_end",
                                                     keep_pth=arch_cfg.get("enable_tune_cktp", True))

        resume_checkpoint = None
        if checkpoint_dir:
            resume_checkpoint = os.path.join(checkpoint_dir, checkpoint_name)

        callbacks = [tune_callback]
        if arch_cfg.pl_trainer.get("callbacks"):
            callbacks = list(instantiate(arch_cfg.pl_trainer.callbacks)) + callbacks

        trainer: Trainer = instantiate(arch_cfg.pl_trainer, callbacks=callbacks, enable_progress_bar=False,
                                       logger=wandb_logger)

    trainer.fit(model=model, datamodule=datamodule, ckpt_path=resume_checkpoint)

    wandb_logger.experiment.unwatch(model)

On the topic of saving checkpoints: I've actually changed Ray Tune's PyTorch Lightning checkpoint callback a little bit so that I have the option to turn off checkpointing completely (as I couldn't do so with this issue). This is now the _TuneCheckpointCallback class that TuneReportCheckpointCallback uses:

import os

from pytorch_lightning import LightningModule, Trainer
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneCallback


class _TuneCheckpointCallback(TuneCallback):
    """PyTorch Lightning checkpoint callback
    Saves checkpoints after each validation step.
    Checkpoints are currently not registered if no ``tune.report()`` call
    is made afterwards. Consider using ``TuneReportCheckpointCallback``
    instead.
    Args:
        filename: Filename of the checkpoint within the checkpoint
            directory. Defaults to "checkpoint".
        on: When to trigger checkpoint creations. Must be one of
            the PyTorch Lightning event hooks (less the ``on_``), e.g.
            "batch_start", or "train_end". Defaults to "validation_end".
        keep_pth: If ``False``, skip writing the checkpoint file entirely,
            effectively disabling checkpointing. Defaults to ``True``.
    """

    def __init__(
            self, filename: str = "checkpoint", on="validation_end", keep_pth=True):
        super(_TuneCheckpointCallback, self).__init__(on)
        self._filename = filename
        self.keep_pth = keep_pth

    def _handle(self, trainer: Trainer, pl_module: LightningModule):
        if trainer.sanity_checking:
            return
        step = f"epoch-{trainer.current_epoch}-step-{trainer.global_step}"
        with tune.checkpoint_dir(step=step) as checkpoint_dir:
            if self.keep_pth:
                trainer.save_checkpoint(os.path.join(checkpoint_dir, self._filename), weights_only=True)

I've had to implement something similar in the regular PyTorch script as well.
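Roughly, the plain-PyTorch version guards the checkpoint write with the same flag; a sketch (placeholder model and metric, not the exact code from that script):

import os

import torch
import torch.nn as nn
from ray import tune

def torch_train_func(config, checkpoint_dir=None):
    model = nn.Linear(4, 1)  # placeholder model
    for epoch in range(2):
        loss = float(model(torch.randn(8, 4)).mean())
        with tune.checkpoint_dir(step=epoch) as ckpt_dir:
            # Only write the .ckpt file when checkpointing is enabled,
            # mirroring keep_pth in the Lightning callback above.
            if config.get("enable_tune_cktp", True):
                torch.save(model.state_dict(), os.path.join(ckpt_dir, "model_checkpoint.ckpt"))
        tune.report(loss=loss)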

I save all checkpoints/logs to the local disk, not to the cloud.

@stephano41 Could you also link how you call tune.run? I am trying to reproduce your workflow locally. Thanks!!

Thank you for your help!

Here is the code for calling tune.run:

import hydra
from hydra.utils import instantiate
from pathlib import Path
from ray import tune
from ray.tune import ExperimentAnalysis


@hydra.main(config_path='conf/', config_name='tune', version_base='1.2')
def main(config):
    output_dir = Path("<some_full_output_dir>")

    analysis: ExperimentAnalysis = tune.run(tune.with_parameters(pl_train_func, arch_cfg=config),
                                            **(instantiate(config.run, _convert_="partial")),
                                            local_dir=output_dir.parent,
                                            name=output_dir.name)

pl_train_func is as described above.

config.run contains the following parameters, which are passed to tune.run as keyword arguments via the ** operator (see the sketch after the config for how the scheduler entry is instantiated):

metric: val/f1
mode: max
verbose: 2
keep_checkpoints_num: 5
checkpoint_score_attr: val/f1
num_samples: 100
scheduler:
  _target_: ray.tune.schedulers.ASHAScheduler
  max_t: 50
  grace_period: 1
  reduction_factor: 2
config:
  wandb:
    #<wandb_config_including_project_name_group_name>
  #<other_hyperparameter_config>
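For clarity, the instantiate(config.run, _convert_="partial") call in main builds the ASHAScheduler from the scheduler node and leaves the remaining keys as plain values, so the result can be unpacked straight into tune.run. A small sketch of just that step, mirroring the call above (config trimmed down):

from hydra.utils import instantiate
from omegaconf import OmegaConf

run_cfg = OmegaConf.create({
    "metric": "val/f1",
    "mode": "max",
    "num_samples": 100,
    "scheduler": {
        "_target_": "ray.tune.schedulers.ASHAScheduler",
        "max_t": 50,
        "grace_period": 1,
        "reduction_factor": 2,
    },
})
run_kwargs = instantiate(run_cfg, _convert_="partial")
# expected: an ASHAScheduler (AsyncHyperBandScheduler) instance
print(type(run_kwargs["scheduler"]))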

Other custom utility functions used:

import contextlib
import os
from pathlib import Path

import yaml
from omegaconf import OmegaConf


@contextlib.contextmanager
def open_struct(config):
    OmegaConf.set_struct(config, False)
    try:
        yield
    finally:
        OmegaConf.set_struct(config, True)


@contextlib.contextmanager
def change_directory(path):
    """Changes the working directory and returns to the previous one on exit."""
    prev_cwd = Path.cwd()
    if path:
        os.chdir(path)
    try:
        yield
    finally:
        os.chdir(prev_cwd)


def write_yaml(content, fname):
    with fname.open('wt') as handle:
        yaml.dump(content, handle, indent=2, sort_keys=False)


def write_conf(config, save_path):
    save_path = Path(save_path)
    save_path.parent.mkdir(parents=True, exist_ok=True)
    config_dict = OmegaConf.to_container(config, resolve=True)
    write_yaml(config_dict, save_path)