Long initialization time in initialize_session with a large-scale dataset

How severe does this issue affect your experience of using Ray?

  • Medium: It contributes to significant difficulty to complete my task, but I can work around it.

Issue

I am working on distributed training with Ray. The issue I have is a very long job initialization time before training starts when I use a large-scale dataset. For example, if I launch the job on a 4x8 cluster, the initialization time is roughly 32 * 1 min, so all 32 GPUs sit idle for about 30 minutes while the training session initializes.
In short, the question is: why does my job initialize all ranks sequentially instead of initializing each rank in parallel?

Observation

  1. The total initialization time grows linearly with the world rank (see the timing sketch after this list).
  2. The initialization time of a single rank grows linearly with the dataset size; initialization finishes almost instantly with a small dev dataset. I think it might be related to the size of the dataset metadata and exec_plan: if I increase the number of rows per Parquet file, the init time on each rank drops, because the total file count drops.
  3. The issue is caused by the initialize_session part of backend_executor.py. Here is the code link
  4. We found that initialize_session is called on each actor of the worker group sequentially, not in parallel. Here is clear evidence from our monitoring results:
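
A rough sketch of how the per-rank startup time was measured (the logging here is illustrative, not our actual monitoring code; it just timestamps when each rank reaches the training loop):

import time
import ray.train

def train_loop_per_worker():
    # Log a timestamp as soon as each rank enters the loop; comparing these
    # timestamps across ranks is what shows the sequential staircase pattern.
    rank = ray.train.get_context().get_world_rank()
    print(f"[rank {rank}] entered train loop at {time.time():.0f}")
    # ... rest of the training loop ...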

Setup

  • Platform: AWS + K8S + KubeRay
  • Framework: Ray Data (read_parquet) + Ray Train (TorchTrainer) + Pytorch-Lightning
  • Ray version: 2.20

Failed Solutions

  1. I thought self.dataset_shards[index] might be the blocker, because the class instance becomes huge with a large-scale dataset. I tried putting it into the object store and loading it at execution time (a sketch of the pattern is below this list), but that seemed to take even longer.
  2. I thought the sequential initialization might come from object serialization and the Python GIL, so I tried using a process pool to call execute_single_async, and the whole training process got stuck.
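
A minimal, self-contained sketch of the object-store pattern I tried for (1); the Worker class and data here are stand-ins, not the actual backend_executor.py change:

import ray

ray.init()

@ray.remote
class Worker:
    def initialize(self, shard_ref_holder):
        # The ObjectRef is wrapped in a list so Ray does not resolve it at call
        # time; it is fetched from the object store here, at execution time.
        shard = ray.get(shard_ref_holder[0])
        return len(shard)

big_shard = list(range(1_000_000))   # stand-in for self.dataset_shards[index]
shard_ref = ray.put(big_shard)       # stored once in the object store

workers = [Worker.remote() for _ in range(4)]
print(ray.get([w.initialize.remote([shard_ref]) for w in workers]))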

@Tairui_Wang Thanks for the question. However, the code you linked actually runs in parallel: the for loop just creates a list of futures and returns immediately, and the real workload is distributed to the 32 worker processes, which run in parallel.

initialize_session normally returns within one to a few seconds in the 32-process case. The figure above looks interesting; it might be related to a misuse of Ray Data. It would be helpful if you could provide a minimal repro of your code.
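
For reference, a self-contained illustration of the dispatch pattern described above (the TrainWorker actor here is a stand-in, not the real Ray Train worker): the loop only submits remote calls and collects futures, and the wait happens later when the futures are resolved.

import time
import ray

ray.init()

@ray.remote
class TrainWorker:
    def initialize_session(self):
        time.sleep(2)   # pretend setup work
        return "ready"

workers = [TrainWorker.remote() for _ in range(8)]

start = time.time()
futures = [w.initialize_session.remote() for w in workers]    # returns immediately
print(ray.get(futures), f"took {time.time() - start:.1f}s")   # ~2 s total, not 16 s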

@hpguo Thanks so much for your reply. Here is my simplified code with some key configs, pretty much following the Ray docs. I can provide a runnable script if needed, but it will take some extra time.

import ray
import logging
import pytorch_lightning as pl
from ray.train.torch import TorchTrainer
from ray.train import CheckpointConfig as RayCheckpointConfig
from pyarrow import fs
from pyarrow.fs import AwsStandardS3RetryStrategy


logger = logging.getLogger(__name__)


class DataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return self._get_torch_dataloaders()

    def _get_torch_dataloaders(self):
        data_shard = ray.train.get_dataset_shard("train")
        return data_shard.iter_torch_batches(
            prefetch_batches=1,
            batch_size=batch_size,
            collate_fn=my_collate,
            drop_last=True,
            local_shuffle_seed=100,
        )

def train_loop_per_worker():
    logger.info("** train_loop_per_worker **")
    # Normal Pytorch Module
    model = Model()
    train_config = {}
    train_config["strategy"] = ray.train.lightning.RayDDPStrategy(
        find_unused_parameters=False, static_graph=True
    )
    data_module = DataModule()
    trainer_shared_kwargs = {
        "datamodule": data_module,
    }
    trainer = pl.Trainer(**train_config)
    trainer = ray.train.lightning.prepare_trainer(trainer)
    trainer.fit(model, **trainer_shared_kwargs)


file_paths, total_data_count = get_s3_files()
override_num_blocks = total_data_count // 2
concurrency = 12
train_dataset = ray.data.read_parquet(
    file_paths,
    shuffle="files",
    concurrency=concurrency,
    override_num_blocks=override_num_blocks,
    ray_remote_args={"max_retries": 3, "retry_exceptions": True},
    filesystem=fs.S3FileSystem(region=S3_BUCKET_REGION,
                                retry_strategy=AwsStandardS3RetryStrategy(max_attempts=3)),
)
# Not sure whether data_preprocess can cause the issue
train_dataset = train_dataset.map(data_preprocess, concurrency=concurrency * 3)
ray_datasets = {
    "train": {
        "dataset": train_dataset,
        "count":total_data_count 
    }
}

execution_options = ray.data.ExecutionOptions()
execution_options.verbose_progress = False
execution_options.locality_with_output = True
execution_options.actor_locality_enabled = False
execution_options.preserve_order = False
ray.data.DataContext.get_current().enable_progress_bars = False
ray_data_config = ray.train.DataConfig(execution_options=execution_options)

scaling_config = ray.train.ScalingConfig(
    num_workers=32, # 4 * 8 cluster
    use_gpu=True,
    resources_per_worker={
        "CPU": 8,
        "GPU": 1,
    },
    placement_strategy="PACK",
)

run_config = ray.train.RunConfig(
    storage_path="/efs/ray_job",
    name="ray_train",
    checkpoint_config=RayCheckpointConfig(),
    failure_config=ray.train.FailureConfig(max_failures=3),
    verbose=0,
)

ray_trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    # TorchTrainer expects Dict[str, Dataset], so pass only the Dataset objects
    datasets={name: d["dataset"] for name, d in ray_datasets.items()},
    dataset_config=ray_data_config,
    scaling_config=scaling_config,
    run_config=run_config,
)
ray_trainer.fit()

@Tairui_Wang Do you know what line of code the driver and the workers are on during that 32-minute idle window? Have you narrowed it down to the worker actors hanging on that initialize_session call?

You can take a look at the stack traces of the driver and the RayTrainWorker actors in the Ray Dashboard: Ray Dashboard — Ray 2.41.0
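
If the Dashboard is hard to navigate at that scale, a hedged alternative (assuming the ray.util.state API is available in your Ray version) is to list the live RayTrainWorker actors and their PIDs, then inspect a few of them with py-spy dump --pid <pid> on the corresponding node:

from ray.util.state import list_actors

# Print the ID, node, and PID of each live train worker so you can attach
# py-spy to it, or look the actor up on the Dashboard's actor page.
for actor in list_actors(filters=[("state", "=", "ALIVE")]):
    if "RayTrainWorker" in actor.class_name:
        print(actor.actor_id, actor.node_id, actor.pid)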