Ray Tune is slowing down Lightning model training by 3x

As the title suggests, when running under Ray Tune an epoch takes roughly 3 hours, whereas it normally takes about an hour to complete. Here's the flame graph from a 20-second profiling run with ~6000 records collected.

Here is the trainer:

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy

trainer = Trainer(
    logger=logger,
    accelerator="gpu",
    devices="auto",
    strategy=DDPStrategy(
        find_unused_parameters=False, gradient_as_bucket_view=True
    ),
    max_epochs=args["epochs"],
    callbacks=callbacks,
    auto_lr_find=False,  # the learning rate will be set manually
)

And the Ray Tune setup:

from pytorch_lightning.callbacks import Callback
from ray import tune
from ray.tune.integration.pytorch_lightning import TuneReportCheckpointCallback

resources_per_trial = {
    "cpu": args.cpus,
    "gpu": args.gpus,
}

callbacks: list[Callback] = [
    TuneReportCheckpointCallback(
        metrics=["valid_loss"],
        filename="checkpoint",
        on=["validation_end", "train_end"],
    )
]
trainable = tune.with_parameters(main, callbacks=callbacks)
tune.run(
    trainable,
    name=args.name,
    local_dir=args.log_dir,
    config=config,
    resources_per_trial=resources_per_trial,
    resume=args.resume,
)
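
For completeness, here's roughly how the trainable ties these together. This is a simplified sketch: MyLightningModule and the lr/epochs config keys stand in for the real model and hyperparameters, and the logger/args handling from the actual code is omitted.

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DDPStrategy


def main(config, callbacks=None):
    # `config` is the sampled Tune search space; `callbacks` is bound via
    # tune.with_parameters above (the TuneReportCheckpointCallback).
    model = MyLightningModule(lr=config["lr"])  # placeholder for the real LightningModule
    trainer = Trainer(
        accelerator="gpu",
        devices="auto",
        strategy=DDPStrategy(
            find_unused_parameters=False, gradient_as_bucket_view=True
        ),
        max_epochs=config.get("epochs", 10),  # the real code uses args["epochs"]
        callbacks=callbacks or [],
        auto_lr_find=False,  # the learning rate comes from config["lr"] instead
    )
    trainer.fit(model)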

Hey, thanks for doing the profiling! Do you mind sharing a simple code reproduction so that we can test and fix it?

Thanks for the quick reply @rliaw. I've edited the post to include the trainer initialization and the Ray Tune setup. I know a minimal, complete, and verifiable example would be easier to work with, but the model is a bit complex. Do you think we can conclude anything from the code above, together with the flame graph and the following dump?

Also, although I decorate the trainable with @wandb_mixin, I use WandbLogger from PyTorch Lightning, because Ray doesn't seem to expose its own WandbLogger implementation that I could pass to the Trainer; I only use wandb_mixin to skip the wandb initialization lines.
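
Concretely, the wiring looks roughly like this (a sketch: the wandb_mixin import path can differ between Ray versions, and the mixin expects a "wandb" section in the Tune config):

from pytorch_lightning.loggers import WandbLogger
from ray.tune.integration.wandb import wandb_mixin  # import path may differ across Ray versions


@wandb_mixin
def main(config, callbacks=None):
    # The mixin handles wandb.init() from config["wandb"], so this WandbLogger
    # should attach to the already-initialized run rather than starting a new one.
    logger = WandbLogger()
    ...  # build the model and Trainer(logger=logger, ...) as shown above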

It looks like it’s wandb’s code that is slowing things down.

I know you definitely want to use wandb, but do you mind removing the wandb mixin for now to see what happens?
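
Something like this, just for the comparison: the same trainable with the wandb pieces stripped out (build_model here is a placeholder for however the model is constructed):

from pytorch_lightning import Trainer


def main(config, callbacks=None):  # note: no @wandb_mixin decorator
    model = build_model(config)  # placeholder for the existing model construction
    trainer = Trainer(
        logger=False,  # disable Lightning logging entirely for the comparison
        accelerator="gpu",
        devices="auto",
        callbacks=callbacks or [],
        max_epochs=config.get("epochs", 10),
    )
    trainer.fit(model)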

I actually tried that; here's the dump without wandb_mixin and WandbLogger.

On a separate note, I also optimized pad_adj, which was taking a lot of time. That's unrelated; I only mention it because it was referenced in the previous dump.

Interesting… I wonder if there's some port conflict between Ray and PyTorch…

Seems related - Launching two processes causes hanging · Issue #50669 · pytorch/pytorch · GitHub
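
One way to rule out a port clash (just a sketch, not verified as the actual fix): give each trial its own DDP rendezvous port before the Trainer starts, since Lightning should pick up the MASTER_PORT environment variable.

import os
import socket


def _find_free_port() -> int:
    # Ask the OS for an unused TCP port so concurrent trials don't collide.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


def main(config, callbacks=None):
    # Set the rendezvous port before building the Trainer / DDPStrategy.
    os.environ.setdefault("MASTER_PORT", str(_find_free_port()))
    # ... build the model and Trainer exactly as in the original post ...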

cc @amogkam @Jiao_Dong