Training time not decreasing with more workers

1. Severity of the issue: (select one)

None: I’m just curious or want clarification.

Low: Annoying but doesn’t hinder my work.

Medium: Significantly affects my productivity but can find a workaround.

High: Completely blocks me.

2. Environment:

  • Ray version: 2.37.0
  • Python version: 3.11.0
  • OS: Linux
  • Cloud/Infrastructure: Databricks
  • Other libs/tools (if relevant): lightning.pytorch 2.2.1

3. What happened vs. what you expected:

  • Expected: When I added more workers to my Ray cluster, I expected the training time to decrease roughly linearly.
  • Actual: Training time stayed about the same regardless of the number of workers.

I’m a data engineer helping with an ML task: distributing the workload of our PyTorch Lightning training process. We already use Ray for some other tasks, so I tried following the steps here to distribute training on a Ray cluster on Databricks.

Here’s a sample of the code. Unfortunately it’s not possible to share a fully reproducible example, but I’m wondering if I’m missing something obvious that might be causing a bottleneck.

from typing import Iterable, Literal, cast

import lightning.pytorch as pl
import ray.data
import ray.train
import ray.train.lightning as ray_lightning
import torch

NUM_WORKERS = 2 # obviously changing this between test runs

BatchedTrainingData = Iterable[dict[str, torch.Tensor]]

PREFETCH_BATCHES = 100

def get_train_loader(train_data, batch_size: int) -> BatchedTrainingData:
    return cast(
        BatchedTrainingData,
        train_data.materialize().iter_torch_batches(
            batch_size=batch_size,
            device="cuda",
            prefetch_batches=PREFETCH_BATCHES,
            local_shuffle_buffer_size=batch_size*4,
        ),
    )

def get_validation_loader(
    validation_data, batch_size: int, device: Literal["cpu", "cuda"] = "cuda"
) -> BatchedTrainingData:
    return cast(
        BatchedTrainingData,
        validation_data.materialize().iter_torch_batches(
            prefetch_batches=PREFETCH_BATCHES, batch_size=batch_size, device=device
        ),
    )

def distributed_train_func(config: dict):
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True
    data_dir = config["data_dir"]

    train_data = ray.data.read_parquet(data_dir / "train.parquet")
    validation_data = ray.data.read_parquet(data_dir / "validation.parquet")
    batch_size = 1024 * NUM_WORKERS
    train_loader = get_train_loader(train_data, batch_size=batch_size)
    val_loader = get_validation_loader(validation_data, batch_size=batch_size)

    model = Model() # our custom model that inherits from pl.LightningModule

    trainer = pl.Trainer(
        deterministic=False,
        gradient_clip_val=5,
        max_epochs=2,
        callbacks=[
            pl.callbacks.EarlyStopping(
                monitor="loss/validation", patience=6, min_delta=0.00005
            ),
            pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
            ray_lightning.RayTrainReportCallback(),
        ],
        num_sanity_val_steps=0,
        precision="bf16-mixed",
        strategy=ray_lightning.RayDDPStrategy(find_unused_parameters=True),
        plugins=[ray_lightning.RayLightningEnvironment()],
        enable_checkpointing=False,
    )
    trainer = ray_lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

scaling_config = ScalingConfig(num_workers=NUM_WORKERS, use_gpu=True)
config = {
    "data_dir": trainer_path,
}

trainer = TorchTrainer(
    distributed_train_func,
    scaling_config=scaling_config,
    run_config=ray.train.RunConfig(storage_path=str(trainer_path / "ray_storage")),
    train_loop_config=config,
)
result: ray.train.Result = trainer.fit()
  • Looking at the Ray dashboard, I can see all nodes are getting utilized, including the GPUs (a quick per-worker check is sketched after this list)
  • I’ve tried tweaking the batch size. Increasing it just leads to a proportional increase in the time to process each batch
  • I tried increasing PREFETCH_BATCHES, but that didn’t change anything
  • I’ve also tried not calling .materialize() in the get_train/validation_loader functions, but that didn’t change anything
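
In case it’s useful for anyone reproducing this, here’s roughly how I’d check whether each worker sees the whole dataset or only a shard (an illustrative sketch I haven’t run as-is, launched the same way as the training function above): if every worker reports the same batch count no matter what NUM_WORKERS is, each worker is iterating the full dataset rather than a shard.

import ray.data
import ray.train

def count_batches_per_worker(config: dict):
    # Illustrative only: log how many batches this worker actually consumes.
    rank = ray.train.get_context().get_world_rank()
    data = ray.data.read_parquet(str(config["data_dir"] / "train.parquet"))
    n_batches = sum(1 for _ in data.iter_torch_batches(batch_size=1024))
    print(f"worker {rank}: {n_batches} batches")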

If anyone has any advice, I’d really appreciate it! Many thanks 🙂

I eventually found the answer in the docs here.
I think it would be helpful if there were a pointer to this in the PyTorch Lightning docs, as there is in the PyTorch ones. As they stand, the PyTorch Lightning docs make it seem like you can provide any data loader and it will work automatically.
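
For anyone who lands here later: assuming the page I linked is Ray Train’s data-loading user guide, my understanding of the pattern it describes (a rough sketch, not my exact final code) is that the Ray Datasets get passed to the TorchTrainer via its datasets argument and consumed with ray.train.get_dataset_shard() inside the training function, so each worker only iterates over its own shard, instead of calling ray.data.read_parquet() inside the training function as I did above.

import ray.data
import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

def sharded_train_func(config: dict):
    # Each worker receives only its own shard of the dataset here.
    train_shard = ray.train.get_dataset_shard("train")
    for batch in train_shard.iter_torch_batches(batch_size=1024, device="cuda"):
        ...  # training step (the Lightning setup above still applies)

trainer = TorchTrainer(
    sharded_train_func,
    scaling_config=ScalingConfig(num_workers=NUM_WORKERS, use_gpu=True),
    datasets={
        "train": ray.data.read_parquet(str(trainer_path / "train.parquet")),
        "validation": ray.data.read_parquet(str(trainer_path / "validation.parquet")),
    },
)
result = trainer.fit()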

Thanks for catching this!

Would you be interested in opening a PR to update the user guide?