This PR adds a new CNN class in src/cnn.py for the MNIST dataset. Alongside the class, src/cnn.py provides functions to load and preprocess the data and to train the model. The existing code in src/main.py for loading and preprocessing MNIST and defining the PyTorch model has been removed and replaced with the new module. Additionally, src/api.py now uses the new CNN model instead of the previous Net model.
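For reviewers, here is a minimal sketch of what a CNN class with an `__init__` and a `forward` method can look like for MNIST. The layer sizes are assumptions modeled on the common PyTorch MNIST example, not taken from src/cnn.py: two 3x3 convolutions, one 2x2 max-pool, and two fully connected layers. The key constraint it illustrates is that the flattened size feeding the first linear layer must match what the convolutions produce from a 28x28 input.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CNN(nn.Module):
    """Illustrative MNIST CNN; layer sizes are assumptions, not the PR's exact values."""

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)   # 1x28x28  -> 32x26x26
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)  # 32x26x26 -> 64x24x24
        self.fc1 = nn.Linear(64 * 12 * 12, 128)        # after 2x2 max-pool: 64x12x12
        self.fc2 = nn.Linear(128, 10)                  # one logit per digit class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = F.relu(self.fc1(x))
        return self.fc2(x)       # raw logits; pair with cross-entropy loss


model = CNN()
logits = model(torch.randn(8, 1, 28, 28))  # batch of 8 single-channel 28x28 images
```

Each 3x3 convolution without padding shrinks the spatial size by 2 (28 to 26 to 24), and the pool halves it to 12, which is where `64 * 12 * 12` comes from.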
Summary of Changes
Created a new file src/cnn.py to contain the new CNN class.
Imported necessary libraries in src/cnn.py for building the CNN model.
Defined a new class CNN in src/cnn.py that inherits from torch.nn.Module.
Implemented the __init__ method in CNN class to define the layers of the CNN.
Implemented the forward method in CNN class to perform the forward pass of the CNN.
Defined a load_data function in src/cnn.py to load and preprocess the MNIST dataset.
Defined a train function in src/cnn.py to train the CNN model on the MNIST dataset.
Added a main function in src/cnn.py to create an instance of the CNN class, load the data, and train the model.
Updated src/main.py to import the CNN class from src/cnn.py.
Removed the code in src/main.py for loading and preprocessing the MNIST dataset and defining the PyTorch model.
Added code in src/main.py to create an instance of the CNN class and train it by calling train.
Updated src/api.py to import the CNN class from src/cnn.py.
Replaced the usage of the previous Net model with the new CNN model in src/api.py.
Updated the path to the state dict file in src/api.py to match the location of the saved CNN model.
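The training and serving flow listed above (train the model, save its state dict, then load it in src/api.py) can be sketched as follows. This is a hedged illustration, not the code in this PR: `make_model` is a hypothetical stand-in for `CNN()`, the Adam optimizer and learning rate are assumptions, random tensors replace the MNIST data that load_data returns, and an in-memory buffer replaces the state dict file path used by src/api.py.

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


def make_model() -> nn.Module:
    # Hypothetical stand-in for CNN(); any nn.Module producing 10 logits works here.
    return nn.Sequential(
        nn.Conv2d(1, 8, kernel_size=3),  # 1x28x28 -> 8x26x26
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(8 * 26 * 26, 10),
    )


def train(model: nn.Module, loader: DataLoader, epochs: int = 1, lr: float = 1e-3) -> nn.Module:
    """Minimal supervised loop with cross-entropy loss over the 10 digit classes."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # optimizer choice is an assumption
    model.train()
    for _ in range(epochs):
        for images, labels in loader:
            optimizer.zero_grad()
            loss = F.cross_entropy(model(images), labels)
            loss.backward()
            optimizer.step()
    return model


# Synthetic MNIST-shaped data in place of whatever load_data() actually returns.
images = torch.randn(64, 1, 28, 28)
labels = torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(images, labels), batch_size=16, shuffle=True)

model = train(make_model(), loader)

# Save the weights, then restore them into a fresh instance, as api.py would.
# The in-memory buffer stands in for the state dict file on disk.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

restored = make_model()
restored.load_state_dict(torch.load(buffer, map_location="cpu"))
restored.eval()  # switch off training-only behaviour before serving predictions

with torch.no_grad():
    predictions = restored(images[:4]).argmax(dim=1)  # predicted digit per image
```

Calling `eval()` on the restored model matters whenever the real CNN uses dropout or batch normalization, since those layers behave differently in training mode.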
Please review and merge this PR to incorporate the changes.
Fixes #9.