## ❓ Questions and Help
### Before asking:
1. Try to find answers to your questions in [the Lightning Forum!](https://forums.pytorchlightning.ai/)
2. Search for similar [issues](https://github.com/PyTorchLightning/pytorch-lightning/issues).
3. Search the [docs](https://pytorch-lightning.readthedocs.io/en/latest/).
#### What is your question?
I am trying to use a custom `IterationBasedBatchSampler`. When I call `trainer.fit(model, train_dataloader)` with `'ddp'` and multiple GPUs, the error below is raised.
How can I use this `IterationBasedBatchSampler` with DDP? Or what do I need to check?
```
Traceback (most recent call last):
  File "lea_train_net.py", line 69, in <module>
    main(args)
  File "lea_train_net.py", line 58, in main
    trainer.fit(model, data_loader_train, data_loader_valid)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 470, in fit
    results = self.accelerator_backend.train()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 143, in train
    results = self.ddp_train(process_idx=self.task_idx, model=model)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 298, in ddp_train
    results = self.train_or_test()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/accelerators/accelerator.py", line 65, in train_or_test
    results = self.trainer.train()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 521, in train
    self.train_loop.run_training_epoch()
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 554, in run_training_epoch
    for batch_idx, (batch, is_last_batch) in train_dataloader:
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/profiler/profilers.py", line 83, in profile_iterable
    value = next(iterator)
  File "/usr/local/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py", line 46, in _with_is_last
    last = next(it)
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/usr/local/lib/python3.7/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
KeyError: Caught KeyError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.7/site-packages/torch/utils/data/_utils/collate.py", line 45, in default_collate
    elem = batch[0]
KeyError: 0
```
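If I read the traceback correctly, `default_collate` got something it cannot index with `batch[0]`: it expects a list of samples, and indexing a dict with `0` raises exactly this `KeyError: 0`. A minimal sketch of just that collate failure (my own toy example, not my real data):
```
from torch.utils.data._utils.collate import default_collate

# default_collate expects a list of samples, e.g. [{"image": ..., "label": ...}, ...].
# Handing it a single dict makes its first line, `elem = batch[0]`, fail:
default_collate({"image": None, "label": None})  # raises KeyError: 0
```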
#### Code
My custom `IterationBasedBatchSampler` is below:
```
import torch


class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """
    Wraps a BatchSampler, resampling from it until
    a specified number of iterations have been sampled.
    """

    def __init__(self, batch_sampler, num_iterations, start_iter=0):
        self.batch_sampler = batch_sampler
        self.num_iterations = num_iterations
        self.start_iter = start_iter

    def __iter__(self):
        iteration = self.start_iter
        while iteration <= self.num_iterations:
            # If the underlying sampler has a set_epoch method, like
            # DistributedSampler (used to make each process see a
            # different split of the dataset), then set it.
            if hasattr(self.batch_sampler.sampler, "set_epoch"):
                self.batch_sampler.sampler.set_epoch(iteration)
            for batch in self.batch_sampler:
                iteration += 1
                if iteration > self.num_iterations:
                    break
                yield batch

    def __len__(self):
        return self.num_iterations
```
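For reference, this is roughly how I wire the sampler into the DataLoader (a simplified sketch; my real dataset and `collate_fn` are omitted):
```
# Simplified sketch; `dataset` stands in for my real map-style dataset.
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size=8, drop_last=True)
batch_sampler = IterationBasedBatchSampler(batch_sampler, num_iterations=90000)

# With batch_sampler set, batch_size, shuffle, sampler and drop_last
# must be left at their defaults on the DataLoader itself.
data_loader_train = torch.utils.data.DataLoader(
    dataset,
    batch_sampler=batch_sampler,
    num_workers=4,
)
```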
And my training code is below:
```
import pytorch_lightning as pl

trainer = pl.Trainer(
    gpus=2,
    distributed_backend='ddp',
    val_check_interval=1000,
    max_epochs=1,
    checkpoint_callback=checkpoint_callback,
    logger=logger)
trainer.fit(model,
            fake_dataloader_train,
            fake_dataloader_valid)
```
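One thing I suspect: with `distributed_backend='ddp'`, Lightning by default replaces the loader's sampler with a `DistributedSampler` (the `replace_sampler_ddp` Trainer flag), and that replacement may not play well with a custom batch sampler. An untested variant I am considering (sketch only):
```
# Untested guess: stop Lightning from swapping out my custom batch sampler
# and rely on IterationBasedBatchSampler's own set_epoch handling instead.
trainer = pl.Trainer(
    gpus=2,
    distributed_backend='ddp',
    replace_sampler_ddp=False,
    val_check_interval=1000,
    max_epochs=1,
    checkpoint_callback=checkpoint_callback,
    logger=logger)
```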
#### What have you tried?
#### What's your environment?
- OS: [e.g. iOS, Linux, Win]
- Packaging [e.g. pip, conda]
- Version [e.g. 0.5.2.1]