What is the correct way of using get_dataset_shard?

1. Severity of the issue: (select one)
Low: Annoying but doesn’t hinder my work.

2. Environment:

  • Ray version: 2.49.0
  • Python version: 3.10
  • OS: Ubuntu 24.04
  • Cloud/Infrastructure: AWS
  • Other libs/tools (if relevant):

Where should I call get_dataset_shard in the training loop worker function: inside the epoch loop or outside it?

In the following example, get_dataset_shard is called once, before iterating over epochs:

def training_loop_per_worker(config):
    ...
    # === Get Data ===
    train_ds = get_dataset_shard("train")
    ...
    for epoch in range(config["train_epochs"]):
        ...
    ...
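
For context, here is a self-contained sketch of this first pattern with the elided pieces filled in by illustrative placeholders; iter_batches, batch_size=32, and the loop body are my assumptions, not part of the example above:

from ray import train

def training_loop_per_worker(config):
    # Fetch this worker's shard once, outside the epoch loop.
    train_ds = train.get_dataset_shard("train")
    for epoch in range(config["train_epochs"]):
        # Re-iterating the same DataIterator re-reads the shard each epoch.
        for batch in train_ds.iter_batches(batch_size=32):
            pass  # forward/backward pass for this batch would go here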

In another example, it is called inside the epoch loop:

def train_func(config):
    ...
    for epoch in range(config["epochs"]):
        ...
        train_dataset_shard = train.get_dataset_shard("train")
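
In either case, the name passed to get_dataset_shard has to match a key in the datasets dict given to the trainer. Below is a minimal end-to-end sketch of the driver side, assuming a TorchTrainer and a toy in-memory dataset; the dataset contents, num_workers=2, and train_epochs=2 are illustrative only:

import ray
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer

# Toy dataset purely for illustration.
ds = ray.data.from_items([{"x": i, "y": 2 * i} for i in range(1000)])

trainer = TorchTrainer(
    training_loop_per_worker,              # either of the loop styles above
    train_loop_config={"train_epochs": 2},
    scaling_config=ScalingConfig(num_workers=2),
    datasets={"train": ds},                # key "train" is what get_dataset_shard("train") looks up
)
result = trainer.fit()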