1. Severity of the issue: (select one)
- None: I’m just curious or want clarification.
- Low: Annoying but doesn’t hinder my work.
- Medium: Significantly affects my productivity but can find a workaround.
- High: Completely blocks me.
2. Environment:
- Ray version: 2.37.0
- Python version: 3.11.0
- OS: Linux
- Cloud/Infrastructure: Databricks
- Other libs/tools (if relevant): lightning.pytorch 2.2.1
3. What happened vs. what you expected:
- Expected: When I added more workers to my Ray cluster, I expected the training time to go down roughly linearly.
- Actual: Training time stayed about the same regardless of the number of workers.
I’m a data engineer helping with an ML task: distributing the workload of our PyTorch Lightning training process. We already use Ray for some other tasks, so I tried following the steps here to distribute it on a Ray cluster on Databricks.
Here’s a sample of the code. Unfortunately it’s not possible to share a fully reproducible example, but I’m wondering if I’m missing something obvious that might be causing a bottleneck.
from typing import Iterable, Literal, cast

import lightning.pytorch as pl
import ray
import ray.data
import ray.train
import ray.train.lightning as ray_lightning
import torch
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

NUM_WORKERS = 2  # changed between test runs to compare scaling
BatchedTrainingData = Iterable[dict[str, torch.Tensor]]
PREFETCH_BATCHES = 100

def get_train_loader(train_data, batch_size: int) -> BatchedTrainingData:
    return cast(
        BatchedTrainingData,
        train_data.materialize().iter_torch_batches(
            batch_size=batch_size,
            device="cuda",
            prefetch_batches=PREFETCH_BATCHES,
            local_shuffle_buffer_size=batch_size * 4,
        ),
    )

def get_validation_loader(
    validation_data, batch_size: int, device: Literal["cpu", "cuda"] = "cuda"
) -> BatchedTrainingData:
    return cast(
        BatchedTrainingData,
        validation_data.materialize().iter_torch_batches(
            prefetch_batches=PREFETCH_BATCHES, batch_size=batch_size, device=device
        ),
    )

def distributed_train_func(config: dict):
    # Train loop that Ray Train runs on each worker.
    import torch._dynamo

    torch._dynamo.config.suppress_errors = True
    data_dir = config["data_dir"]
    train_data = ray.data.read_parquet(data_dir / "train.parquet")
    validation_data = ray.data.read_parquet(data_dir / "validation.parquet")
    batch_size = 1024 * NUM_WORKERS  # batch size used by each worker's loader
    train_loader = get_train_loader(train_data, batch_size=batch_size)
    val_loader = get_validation_loader(validation_data, batch_size=batch_size)
    model = Model()  # our custom model that inherits from pl.LightningModule
    trainer = pl.Trainer(
        deterministic=False,
        gradient_clip_val=5,
        max_epochs=2,
        callbacks=[
            pl.callbacks.EarlyStopping(
                monitor="loss/validation", patience=6, min_delta=0.00005
            ),
            pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
            ray_lightning.RayTrainReportCallback(),
        ],
        num_sanity_val_steps=0,
        precision="bf16-mixed",
        strategy=ray_lightning.RayDDPStrategy(find_unused_parameters=True),
        plugins=[ray_lightning.RayLightningEnvironment()],
        enable_checkpointing=False,
    )
    trainer = ray_lightning.prepare_trainer(trainer)
    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

scaling_config = ScalingConfig(num_workers=NUM_WORKERS, use_gpu=True)

# trainer_path is a pathlib.Path defined elsewhere in our code (not shown)
config = {
    "data_dir": trainer_path,
}

trainer = TorchTrainer(
    distributed_train_func,
    scaling_config=scaling_config,
    run_config=ray.train.RunConfig(storage_path=str(trainer_path / "ray_storage")),
    train_loop_config=config,
)
result: ray.train.Result = trainer.fit()
- Looking at the ray dashboard, I can see all nodes are getting utilized (including the GPUs)
- I’ve tried tweaking the batch size. Increasing it just makes each batch take proportionally longer to process, so overall throughput doesn’t improve (see the loader-timing sketch below).
- I tried increasing PREFETCH_BATCHES, but that didn’t change anything
- I’ve also tried not calling .materialize() in the get_train_loader / get_validation_loader functions, but that didn’t change anything.
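To make the batch-size observation above more concrete: I haven’t included my exact measurement code, but a simplified sketch of how the Ray Data loader could be timed on its own (reusing trainer_path and PREFETCH_BATCHES from the snippet above, no Lightning involved) would look roughly like this:

# Rough sketch, not my exact code: iterate the loader by itself to see whether
# data ingestion, rather than the DDP training step, is what limits throughput.
# device is deliberately omitted so this only measures ingestion, not GPU transfer.
import time

import ray.data

ds = ray.data.read_parquet(trainer_path / "train.parquet")
loader = ds.iter_torch_batches(
    batch_size=1024,
    prefetch_batches=PREFETCH_BATCHES,
)

start = time.perf_counter()
num_batches = 0
for _ in loader:
    num_batches += 1
elapsed = time.perf_counter() - start
print(f"{num_batches} batches in {elapsed:.1f}s ({num_batches / elapsed:.2f} batches/s)")
print(ds.stats())  # Ray Data's per-operator timing breakdown after iteration

If the standalone loader is already slow, that would point at data ingestion rather than the distributed training itself. I can run something like this and post numbers if that would help.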
If anyone has any advice, I’d really appreciate it! Many thanks