
Dataloader resetting with num_workers=1 and persistent_workers=True #110029

Open jmzeng opened 1 year ago

jmzeng commented 1 year ago

🐛 Describe the bug

I hit the following bug when initializing a PyTorch DataLoader with num_workers=1 and persistent_workers=True. It seems the DataLoader gets reset mid-epoch:

UserWarning: Length of IterableDataset <ConstantLengthDataset object at 0x7f4f3d028790> was reported to be 55936 (when accessing len(dataloader)), but 69087 samples have been fetched. For multiprocessing data-loading, this could be caused by not properly configuring the IterableDataset replica at each worker.
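For context, this warning is emitted by the DataLoader iterator once it has yielded more samples than len(dataloader) reported, and only if len(dataloader) was actually accessed. A minimal sketch that reproduces it, using a hypothetical stand-in for ConstantLengthDataset whose __len__ underreports:

import torch
from torch.utils.data import DataLoader, IterableDataset


# Hypothetical stand-in for ConstantLengthDataset: __len__ reports an
# estimate that is smaller than the number of samples __iter__ yields.
class EstimatedLengthDataset(IterableDataset):
    def __init__(self, true_length, reported_length):
        self.true_length = true_length
        self.reported_length = reported_length

    def __len__(self):
        return self.reported_length  # underestimate of the real sample count

    def __iter__(self):
        for i in range(self.true_length):
            yield torch.tensor([i])


if __name__ == "__main__":
    loader = DataLoader(EstimatedLengthDataset(100, 80), batch_size=1,
                        num_workers=1, persistent_workers=True)
    _ = len(loader)  # the check only arms itself if len(dataloader) is accessed
    for _ in loader:
        pass  # UserWarning fires once the 81st sample is fetched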

I also tried using the worker_init_fn as suggested, where dataset.start and dataset.end are the start and end indices of the split. It still hits the above warning:

import logging
import math

import torch

logger = logging.getLogger(__name__)

def worker_init_fn(worker_id):
    worker_info = torch.utils.data.get_worker_info()
    dataset = worker_info.dataset  # the dataset copy in this worker process
    overall_start = dataset.start
    overall_end = dataset.end
    # configure the dataset to only process the split workload
    per_worker = int(math.ceil((overall_end - overall_start) / float(worker_info.num_workers)))
    worker_id = worker_info.id
    dataset.start = overall_start + worker_id * per_worker
    dataset.end = min(dataset.start + per_worker, overall_end)
    logger.info(f"Initializing worker with split of data from {dataset.start} to {dataset.end}")

From the training perspective, I hit this warning about every 2.46 epochs; the epoch counter then resets to 1 and training continues from there. I suspect it is skipping part of the data, and I worry this may have affected training.

Versions

PyTorch version: 2.0.1+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

Versions of relevant libraries:
[pip3] numpy==1.24.4
[pip3] pytorch-lightning==2.0.8
[pip3] pytorch-ranger==0.1.1
[pip3] torch==2.0.1+cu118
[pip3] torch-optimizer==0.3.0
[pip3] torchdata==0.6.1
[pip3] torchmetrics==0.11.4
[pip3] torchtext==0.15.2+cpu
[pip3] torchvision==0.15.2+cu118
[pip3] triton==2.0.0
[pip3] triton-pre-mlir==2.0.0
[conda] Could not collect

cc @SsnL @VitalyFedyunin @ejguan @dzhulgakov

gokulavasan commented 7 months ago

@jmzeng With num_workers=1, the entire dataset (all of its samples) is read by that single worker (I am assuming the number of ranks is 1 in this case). If so, I am wondering how the IterableDataset is able to produce 69087 samples when the length it reports is 55936. How is the __len__ method of your IterableDataset implemented in your use case?
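A quick way to check (sketch, assuming dataset is your ConstantLengthDataset instance) is to compare the reported length against the number of samples one full pass actually yields:

# Sketch: compare the reported length with an actual single-process pass.
reported = len(dataset)            # what len(dataloader) is derived from
actual = sum(1 for _ in dataset)   # count samples yielded by one iteration
print(f"reported={reported}, actual={actual}")  # these should match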