Training loop stuck at StreamSplitDataIterator

1. Severity of the issue: (select one)
High: Completely blocks me.

2. Environment:

  • Ray version: 2.49.0
  • Python version: 3.10.12
  • OS: Ubuntu 24.04
  • Cloud/Infrastructure: AWS
  • Other libs/tools (if relevant):

3. What happened vs. what you expected:

  • Expected: Ray training continues for 10 epochs.
  • Actual: Ray Train gets stuck at the end of an epoch.

I’m using the "Fine-tuning a face mask detection model with Faster R-CNN — Ray 2.49.1" example to train a Faster R-CNN model with minor modifications.

If I use more than one worker in scaling_config, the training loop gets stuck at the warning shown below.

The behavior is a bit random: sometimes it finishes 2 epochs before hanging at the following warning, and sometimes it hangs after finishing just 1 epoch.

...
...
StreamSplitDataIterator(epoch=1, split=1) blocked waiting on other clients for more than 30s. All clients must read from the DataIterator splits at the same time. This warning will not be printed again for this epoch.

Here is the pseudo code:

import ray
import ray.train
from ray.train.torch import TorchTrainer

# Build the training dataset; ReadDataset loads the images/annotations
# referenced by the parquet file.
train_dataset = ray.data.read_parquet(dataset_info_file)
train_dataset = train_dataset.map(ReadDataset, concurrency=2, num_cpus=4)

train_loop_config = {
    "num_epochs": 10,
    "lr": 0.01,
    "batch_size": 4,
    "prefetch_batches": 1,
    "enable_mlflow": False,
    "run_name": run_name,
}

run_config = ray.train.RunConfig(
    storage_path=s3_storage_path,
    name=run_name,
)

# Two GPU workers; the hang only shows up with num_workers > 1.
scaling_config = ray.train.ScalingConfig(
    num_workers=2,
    use_gpu=True,
    resources_per_worker={"CPU": 3, "GPU": 1},
)

trainer = TorchTrainer(
    train_loop_per_worker=train_loop_per_worker,
    train_loop_config=train_loop_config,
    datasets={"train": train_dataset},
    scaling_config=scaling_config,
    run_config=run_config,
)

result = trainer.fit()
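
For context, the hang happens inside the per-worker loop that consumes the dataset shard. This is a minimal sketch of the shape of my train_loop_per_worker (model setup, loss computation, and reporting omitted; the structure is illustrative, not my exact code):

import ray.train

def train_loop_per_worker(config):
    # Each worker gets its split of the "train" dataset; this is the
    # StreamSplitDataIterator that the warning refers to.
    train_shard = ray.train.get_dataset_shard("train")

    for epoch in range(config["num_epochs"]):
        for batch in train_shard.iter_batches(
            batch_size=config["batch_size"],
            prefetch_batches=config["prefetch_batches"],
        ):
            ...  # forward/backward pass and optimizer step
        # per-epoch metrics and checkpoint reporting goes here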

I’m a bit stuck here. Can someone guide me on debugging this problem?

Thank you.

I found the problem in my code.

Originally, my code looked like this:

...
if ray.train.get_context().get_world_rank() == 0:
    ...
    ray.train.report(metrics, checkpoint=checkpoint)
...

The documentation for ray.train.report clearly states that it should be called by all workers.

In my code, ray.train.report was never called by the non-zero-rank workers. My guess is that this left the workers out of step: rank 0 paused to report and persist the checkpoint while the other workers moved ahead to the next epoch's dataset split, which matches the StreamSplitDataIterator warning above.

The following change fixed my problem:

...
if ray.train.get_context().get_world_rank() == 0:
    ...
    # Rank 0 reports metrics together with the checkpoint.
    ray.train.report(metrics, checkpoint=checkpoint)
    ...
else:
    # All other ranks still report, but without a checkpoint, so that
    # every worker reaches the report() call on every epoch.
    ray.train.report(metrics, checkpoint=None)
...
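
For anyone hitting the same thing, the per-epoch reporting now looks roughly like this. The temp-dir/torch.save checkpointing part is a simplified sketch following the Ray Train docs, not my exact code:

import os
import tempfile

import torch
import ray.train
from ray.train import Checkpoint

def report_epoch(model, metrics):
    # Every worker must call ray.train.report once per epoch so that all
    # workers stay in step; only rank 0 attaches a checkpoint.
    if ray.train.get_context().get_world_rank() == 0:
        with tempfile.TemporaryDirectory() as tmpdir:
            torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
            ray.train.report(
                metrics, checkpoint=Checkpoint.from_directory(tmpdir)
            )
    else:
        ray.train.report(metrics, checkpoint=None)

With this in place, both workers advance through the dataset splits together and the StreamSplitDataIterator warning no longer appears.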