How to get dataset shard size in each train worker

1. Severity of the issue: (select one)
Low: Annoying but doesn’t hinder my work.

2. Environment:

  • Ray version: 2.46.0
  • Python version: 3.12
  • OS: wsl2 ubuntu

3. What happened vs. what you expected:

  • Expected: a handy method to get the total number of rows in the train dataset shard
  • Actual: I have to calculate it by hand
```python
def train_func(config):
    train_dataset = ray.train.get_dataset_shard("train")
    total_rows = ?  # how to get the total number of rows in train_dataset?

trainer = ray.train.torch.TorchTrainer(
    train_func,
    datasets={"train": load_dataset("xxx")},
    train_loop_config=...,
    scaling_config=...,
)
```

`ray.train.get_dataset_shard` returns a `DataIterator`, and I cannot get its size from it. Is there any way to get the total number of rows inside this shard?
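For what it's worth, here is a sketch of two workarounds, not an official API: count rows on the driver with `Ray Data`'s `Dataset.count()` and pass the total to each worker via `train_loop_config`, or have the worker count rows itself by iterating its shard once. The `"total_rows"` config key and the `count_shard_rows` helper below are illustrative names, not part of Ray.

```python
# Sketch of two workarounds; "total_rows" and count_shard_rows are
# illustrative names, not part of the Ray API.

def train_func(config):
    # Option 1: read the driver-computed total from train_loop_config.
    total_rows = config["total_rows"]
    return total_rows

def count_shard_rows(batch_iter):
    # Option 2: count the shard's rows by iterating it once.
    # Each batch yielded by a DataIterator is a dict mapping column
    # name -> array, so a batch's row count is the length of any column.
    total = 0
    for batch in batch_iter:
        total += len(next(iter(batch.values())))
    return total

# Driver side (Option 1), using the dataset from the issue:
# ds = load_dataset("xxx")
# trainer = ray.train.torch.TorchTrainer(
#     train_func,
#     datasets={"train": ds},
#     train_loop_config={"total_rows": ds.count()},  # count once, up front
#     scaling_config=...,
# )
```

Note that `Dataset.count()` triggers execution of the dataset pipeline, and iterating the shard in Option 2 consumes the `DataIterator`, so the count would need to be taken (or cached) before the actual training loop.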