mrdbourke / pytorch-deep-learning

Materials for the Learn PyTorch for Deep Learning: Zero to Mastery course.
https://learnpytorch.io
MIT License
10.78k stars 3.17k forks

problems with setting up dataloader #1061

Open lazyman001 opened 1 month ago

lazyman001 commented 1 month ago

I set up my dataloader like this:

 NUM_WORKERS = os.cpu_count()

 train_dataloader = DataLoader(
     dataset=train_data_simple,
     batch_size=BATCH_SIZE,
     shuffle=True,
     num_workers=1
 )

and an error is reported whenever num_workers is set to any non-zero value. The error is shown in the attached screenshot.

mrdbourke commented 1 month ago

Hey @lazyman001 ,

Where are you getting this issue?

What's the code you're trying to run?

Have you tried putting all your code into a main() function? And then calling if __name__ == "__main__": main()?


This error occurs because you're trying to use multiple worker processes in PyTorch, and you haven't properly protected the code that starts these processes. Specifically, you need to ensure that the multiprocessing module's main entry point is properly guarded by an if __name__ == '__main__': clause. This is required when using the multiprocessing module on platforms that don’t use the fork system call (such as Windows).

To fix this, make sure your script looks something like this:

import torch
from torch.utils.data import DataLoader

# Your other imports and code here

def main():
    # Your training or data loading code here
    # Example:
    # dataset = YourDataset()
    # dataloader = DataLoader(dataset, num_workers=4)

    pass

if __name__ == '__main__':
    main()
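
For reference, here's a minimal sketch of the guard applied to the snippet from the original post. It assumes train_data_simple is a torchvision ImageFolder dataset; the data directory and image size are hypothetical placeholders, so substitute your own:

import os
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

BATCH_SIZE = 32
NUM_WORKERS = os.cpu_count()  # use all available CPU cores for loading

def main():
    # Hypothetical dataset path and transform -- replace with your own
    simple_transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    train_data_simple = datasets.ImageFolder(root="data/train",
                                             transform=simple_transform)

    train_dataloader = DataLoader(
        dataset=train_data_simple,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,  # safe here: only runs under the __main__ guard
    )

    # Pull one batch to confirm the worker processes start correctly
    images, labels = next(iter(train_dataloader))
    print(images.shape, labels.shape)

if __name__ == "__main__":
    main()

The guard matters because on platforms that use the spawn start method (Windows, and macOS by default on recent Python versions), each worker process re-imports your script; any top-level code that creates or iterates a DataLoader would then try to spawn workers recursively, which is exactly what the guard prevents.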

Someone had a similar issue to this the other day, see: https://github.com/mrdbourke/pytorch-deep-learning/discussions/1059

If you're still having trouble, please post the code you're trying to run and describe where you're running it.

heisdenverr commented 1 week ago

In addition to the missing __main__ guard, there's a mismatch in your snippet: you created a global constant called NUM_WORKERS, but in the DataLoader call you hardcoded num_workers instead of using it.

To fix this, make sure your script looks something like this:

import os  # needed for os.cpu_count() below
import torch
from torch.utils.data import DataLoader

# Your other imports and code here


def main():
    # Your training or data loading code here
    # Example:
    # dataset = YourDataset()
    # NUM_WORKERS = os.cpu_count()
    # dataloader = DataLoader(dataset, num_workers=NUM_WORKERS)

    pass

if __name__ == '__main__':
    main()
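
As a quick workaround while debugging, setting num_workers=0 keeps all data loading in the main process, which avoids multiprocessing (and this error) entirely, at the cost of loading speed. Reusing the names from the snippet above:

# num_workers=0 means no worker processes: data loads in the main process
train_dataloader = DataLoader(
    dataset=train_data_simple,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
)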