1. Severity of the issue: (select one)
Low: Annoying but doesn’t hinder my work.
2. Environment:
- Ray version: 2.46.0
- Python version: 3.12
- OS: Ubuntu on WSL2
3. What happened vs. what you expected:
- Expected: a handy method to get the total number of rows in the train dataset
- Actual: I have to calculate it by hand
```python
def train_func(config):
    train_dataset = ray.train.get_dataset_shard("train")
    total_rows = ?  # how to get the total number of rows in train_dataset?

trainer = ray.train.torch.TorchTrainer(
    train_func,
    datasets={"train": load_dataset("xxx")},
    train_loop_config=...,
    scaling_config=...,
)
```
`ray.train.get_dataset_shard` returns a `DataIterator`, and I cannot get its size from that. Is there any way to get the total number of rows inside this shard?