This PR refactors the training loop from a script into a class for better organization and reusability. The training loop, previously script code in src/main.py, is now encapsulated in a new class, MNISTTrainer, in the same file. The class loads and preprocesses the MNIST dataset, defines the PyTorch model, and trains it. It also provides a method to save the trained model to a file.
Summary of Changes
Created a new class named MNISTTrainer in src/main.py to encapsulate the training loop.
Moved the code for loading and preprocessing the MNIST data into a method named load_data in MNISTTrainer.
Moved the code for defining the PyTorch model into a method named define_model in MNISTTrainer.
Implemented a new method named train_model in MNISTTrainer to contain the training loop.
Created a new method named save_model in MNISTTrainer to save the trained model to a file.
Updated the import statement in src/api.py to import the MNISTTrainer class instead of the Net class.
Updated the code in src/api.py to use the MNISTTrainer class for loading the data, defining the model, training the model, and saving the model to a file.
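For reviewers, here is a minimal sketch of the class shape described above. The four method names (load_data, define_model, train_model, save_model) come from this PR. Everything else is an assumption for illustration and is not the code in this PR: the constructor arguments, the small fully connected network, SGD with lr=0.01, the commonly used MNIST mean/std (0.1307/0.3081), and the optional dataset argument to load_data that lets the class run without downloading MNIST.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


class MNISTTrainer:
    """Hypothetical sketch: method names follow the PR, internals are assumed."""

    def __init__(self, data_dir="data", batch_size=64, lr=0.01, epochs=1):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.lr = lr
        self.epochs = epochs
        self.model = None
        self.train_loader = None

    def load_data(self, dataset=None):
        # Assumption: an injected dataset skips the MNIST download (useful for tests).
        if dataset is None:
            # torchvision is imported lazily so it is only required for real MNIST.
            from torchvision import datasets, transforms

            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ])
            dataset = datasets.MNIST(
                self.data_dir, train=True, download=True, transform=transform
            )
        self.train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        return self.train_loader

    def define_model(self):
        # Assumed architecture: flatten 28x28 images, one hidden layer, 10 class logits.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
        return self.model

    def train_model(self):
        if self.model is None or self.train_loader is None:
            raise RuntimeError("call load_data() and define_model() before train_model()")
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        loss_fn = nn.CrossEntropyLoss()
        self.model.train()
        last_loss = None
        for _ in range(self.epochs):
            for images, labels in self.train_loader:
                optimizer.zero_grad()
                loss = loss_fn(self.model(images), labels)  # forward pass + loss
                loss.backward()  # backpropagate gradients
                optimizer.step()  # update weights
                last_loss = loss.item()
        return last_loss  # loss of the final batch, as a Python float

    def save_model(self, path="mnist_model.pt"):
        # Save only the learned weights (state_dict), not the whole Python object.
        torch.save(self.model.state_dict(), path)
```

With this shape, src/api.py would call trainer.load_data(), trainer.define_model(), trainer.train_model() and trainer.save_model(path) in that order. Returning the loader, the model and the final loss from each step is also an assumption; it simply makes each step easy to check in isolation.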
Please review and merge this PR. Thank you!
Fixes #6.
🎉 Latest improvements to Sweep:
Sweep can now passively improve your repository! Check out Rules to learn more.
💡 To get Sweep to edit this pull request, you can:
Comment below, and Sweep can edit the entire PR
Comment on a file, Sweep will only modify the commented file
Edit the original issue to get Sweep to recreate the PR from scratch