bearpaw / pytorch-classification

Classification with PyTorch.
MIT License
1.69k stars, 563 forks

Refactor Outdated PyTorch Operations for AlexNet Training on CIFAR #57

Open KaledDahleh opened 4 months ago

KaledDahleh commented 4 months ago

This pull request updates the code for compatibility with recent PyTorch releases and fixes tensor-operation errors that occur while training AlexNet on the CIFAR dataset.

The key changes are as follows:

  1. Removed the async=True argument from cuda() calls: This argument is deprecated as of PyTorch 1.0.0 (superseded by non_blocking), and async is a reserved keyword in Python 3.7+, so it has been removed to maintain compatibility.

    • Commit: 5c70d06
  2. Replaced volatile=True with torch.no_grad(): The volatile flag no longer has any effect in current PyTorch, so evaluation code now runs inside the torch.no_grad() context manager, which disables gradient tracking and avoids its overhead.

    • Commit: 4d393ff
  3. Fixed tensor reshaping in the accuracy function: Replaced view with reshape, which also works on non-contiguous tensors, where view can raise a runtime error.

    • Commit: d378d87
  4. Fixed all 0-dim tensor indexing errors in cifar.py: Indexing a 0-dim tensor raised index errors; these accesses have been updated so that training and metric updates run without errors.

    • Commits: c570bd6, 2935578
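
For readers who want to see the four patterns side by side, here is a minimal sketch of the kind of change described above. It is not the PR's diff: the tiny linear model, the random batch, and the exact body of `accuracy` are assumptions chosen so the snippet runs on its own, and using `.item()` for the 0-dim case is the usual replacement rather than something this description confirms.

```python
import torch
import torch.nn as nn

# Stand-in model and batch (hypothetical); the PR itself trains AlexNet on CIFAR.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(8, 10).to(device)
criterion = nn.CrossEntropyLoss()
inputs = torch.randn(4, 8)
targets = torch.tensor([1, 0, 3, 9])

# (1) Old style: targets.cuda(async=True). `async` is a reserved word in
#     Python 3.7+, so a plain transfer without that argument is used instead.
inputs, targets = inputs.to(device), targets.to(device)

# (2) Old style: Variable(inputs, volatile=True). Evaluation now runs under
#     torch.no_grad(), which disables gradient tracking inside the block.
model.eval()
with torch.no_grad():
    outputs = model(inputs)
    loss = criterion(outputs, targets)


def accuracy(output, target, topk=(1,)):
    """Top-k precision in percent (assumed shape of the helper being fixed)."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # the transpose yields a non-contiguous tensor
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # (3) reshape copies when it has to; view(-1) can raise here for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


prec1, prec5 = accuracy(outputs, targets, topk=(1, 5))

# (4) Old style: loss.data[0]. A 0-dim tensor cannot be indexed;
#     .item() returns the value as a plain Python number.
loss_value = loss.item()
top1_value = prec1.item()
```

The snippet falls back to the CPU when no GPU is present, so the same code path can be checked on any machine.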

Additional Information: