[SGD] [Tune] Issue with ray.util.sgd.data.Dataset API

I have been using ray.util.sgd.data.Dataset with PyTorch. As I understand it, the purpose of this API is to ingest large amounts of data into Ray and train on it.

In my case, when I try to train on the data, it displays:

{'num_samples': 351, 'epoch': 1.0, 'batch_count': 1.0, 'train_loss': 1.201025366783142, 'last_train_loss': 1.201025366783142}
{'num_samples': 0, 'epoch': 2.0, 'batch_count': 0.0}
{'num_samples': 0, 'epoch': 3.0, 'batch_count': 0.0}

All of the samples are always consumed in the first epoch. Every later epoch is empty.

My code snippet:

    import torch
    from ray.util.sgd import TorchTrainer
    from ray.util.sgd.torch import TrainingOperator

    # model_creator, optimizer_creator and scheduler_creator are defined elsewhere.
    MyTrainingOperator = TrainingOperator.from_creators(
        model_creator=model_creator,
        optimizer_creator=optimizer_creator,
        loss_creator=torch.nn.CrossEntropyLoss,
        scheduler_creator=scheduler_creator)

    trainer = TorchTrainer(
        training_operator_cls=MyTrainingOperator,
        config={"batch_size": 64})

    # Fetch values from the database using Dataset.
    db_dataset = fetch_values_from_database()

    for i in range(500):
        # Train for another epoch using the dataset.
        stats = trainer.train(dataset=db_dataset, num_steps=200)

I can share more details on the dataset. Thanks in advance.

@SumanthDatta Hmm, this is odd. Could you provide a small reproducible example so I can try it out? Thanks.

Sure, I will try to provide an example. In my case the data is fetched from a database. The fetching is done with a parallel iterator, and that parallel iterator is passed to the Dataset class.

Does the same problem happen for you with the basic example on the docs? It would really help to debug if there is a small, reproducible example I can try out.

@amogkam, I am sharing the code in the link below. I was able to replicate the issue. I might be doing something wrong here. Please help in identifying the issue.


The dataset used is


Download the CSV and provide its path in the code. Thanks in advance. It would be a great help to me if this problem is solved.

@amogkam, any findings so far?

Hey @SumanthDatta, apologies for the late response here. Can you try with this change? Replace

    it = ray.util.iter.from_items(rows_list)

with

    it = ray.util.iter.from_items(rows_list, repeat=True)
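
For anyone else hitting this: the symptom (351 samples in epoch 1, then 0 in every later epoch) is what you would expect from an iterator that can only be consumed once. The sketch below is a plain-Python analogy, not Ray's implementation. It uses `itertools.cycle` as a stand-in for what `repeat=True` is meant to do, and `train_epoch`, the 351-item `rows_list`, and the `num_steps` value of 400 are hypothetical names and numbers chosen only to mirror the output in the question.

```python
import itertools


def train_epoch(batches, num_steps):
    """Consume at most num_steps items from the iterator and count them."""
    seen = sum(1 for _ in itertools.islice(batches, num_steps))
    return {"num_samples": seen}


# Hypothetical stand-in for the rows fetched from the database.
rows_list = list(range(351))

# One-pass iterator: analogous to from_items(rows_list) without repeat.
one_pass = iter(rows_list)
one_pass_stats = [train_epoch(one_pass, num_steps=400) for _ in range(3)]

# Repeating iterator: analogous to from_items(rows_list, repeat=True).
repeating = itertools.cycle(rows_list)
repeating_stats = [train_epoch(repeating, num_steps=400) for _ in range(3)]

print(one_pass_stats)
print(repeating_stats)
```

With the one-pass iterator, the first call drains all the data and later calls see nothing. With the repeating iterator, every call gets a full `num_steps` worth of items, so an "epoch" is then bounded by `num_steps` rather than by the size of the data.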