Open jmzeng opened 1 year ago
@jmzeng In the case of `num_workers=1`, the entire dataset (all of its samples) would be read by just that one worker (I am assuming the number of ranks is 1 here). If that is the case, I am wondering how the IterableDataset is able to produce 69087 samples when the length it reports is 55936. How is the `__len__` method of your IterableDataset implemented in your use case?
🐛 Describe the bug
I hit the following bug when initializing a PyTorch `DataLoader` with `num_workers=1` and `persistent_workers=True`. It seems the `DataLoader` gets reset:

```
UserWarning: Length of IterableDataset <ConstantLengthDataset object at 0x7f4f3d028790> was reported to be 55936 (when accessing len(dataloader)), but 69087 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker.
```
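For context, this warning fires whenever a `DataLoader` yields more samples than the dataset's `__len__` advertised (and `len(dataloader)` was called at some point). A minimal sketch reproducing the mechanism, using a toy dataset that deliberately under-reports its length (`UnderReporting` is a made-up name, not from the original issue):

```python
import warnings
from torch.utils.data import DataLoader, IterableDataset

class UnderReporting(IterableDataset):
    """Toy dataset whose __len__ under-reports the true sample count."""
    def __iter__(self):
        return iter(range(10))  # actually yields 10 samples
    def __len__(self):
        return 5                # but advertises only 5

loader = DataLoader(UnderReporting(), batch_size=None, num_workers=0)
reported = len(loader)          # the DataLoader records that len() was called
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    fetched = sum(1 for _ in loader)
# fetched (10) exceeds reported (5), so a UserWarning like the one
# quoted above should be emitted once the 6th sample is yielded
```

In the original report the same mismatch appears (69087 fetched vs. 55936 reported), which suggests either `__len__` is computed differently from what `__iter__` actually yields, or the worker replica iterates over more data than the length accounts for.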
I also tried using the `worker_init_fn` as suggested, where `dataset.start` and `dataset.end` are the start and end indices. It still hits the above warning.

From the training perspective, I hit this error about every 2.46 epochs: the epoch counter resets to 1 and training continues from there, but I suspect it is skipping part of the data. I worry this might have affected training.
Versions
```
PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] pytorch-lightning==2.0.8
[pip3] pytorch-ranger==0.1.1
[pip3] torch==2.0.1+cu118
[pip3] torch-optimizer==0.3.0
[pip3] torchdata==0.6.1
[pip3] torchmetrics==0.11.4
[pip3] torchtext==0.15.2+cpu
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[pip3] triton-pre-mlir==2.0.0
[conda] Could not collect
```
cc @SsnL @VitalyFedyunin @ejguan @dzhulgakov