This PR refactors the training loop from a standalone script into a class for better organization and reusability. The existing code in src/main.py is encapsulated in a new class named MNISTTrainer. The class is responsible for loading and preprocessing the MNIST data, defining the PyTorch model, training the model, and saving and loading the model for future use.
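The responsibilities above map onto one method each. The following is a minimal sketch of how such a class could be laid out. It is not the code in this PR. The following details are all assumptions made for illustration:

- the MLP architecture
- the hyperparameters (batch size 64, SGD with lr 0.01)
- the "data" download root
- the optional `dataset` argument to load_data, which lets the class be exercised without downloading MNIST

```python
from typing import Optional

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset


class MNISTTrainer:
    """Hypothetical sketch: data loading, model definition and training in one class."""

    def __init__(self, batch_size: int = 64, lr: float = 0.01, device: str = "cpu"):
        # Illustrative hyperparameters, not necessarily the values used in src/main.py.
        self.batch_size = batch_size
        self.lr = lr
        self.device = torch.device(device)
        self.model: Optional[nn.Module] = None
        self.train_loader: Optional[DataLoader] = None

    def load_data(self, dataset: Optional[Dataset] = None) -> DataLoader:
        # By default download MNIST via torchvision; a dataset can be injected instead.
        if dataset is None:
            from torchvision import datasets, transforms

            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),  # commonly used MNIST mean/std
            ])
            dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
        self.train_loader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)
        return self.train_loader

    def define_model(self) -> nn.Module:
        # Assumed architecture: a small MLP over flattened 28x28 images, 10 output classes.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ).to(self.device)
        return self.model

    def train(self, epochs: int = 1) -> float:
        # Standard loop: forward pass, cross-entropy loss, backward pass, SGD step.
        if self.model is None:
            self.define_model()
        if self.train_loader is None:
            self.load_data()
        optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr)
        loss_fn = nn.CrossEntropyLoss()
        self.model.train()  # nn.Module.train(): switch the model to training mode
        last_loss = float("nan")
        for _ in range(epochs):
            for images, labels in self.train_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                optimizer.zero_grad()
                loss = loss_fn(self.model(images), labels)
                loss.backward()
                optimizer.step()
                last_loss = loss.item()
        return last_loss  # loss of the final batch, for quick sanity checks
```

With this layout, a training run becomes `trainer = MNISTTrainer()`, `trainer.load_data()`, `trainer.define_model()`, `trainer.train(epochs=1)`. Each step can also be called or tested on its own.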
Summary of Changes
Created a new class named MNISTTrainer in src/main.py.
Moved the code for loading and preprocessing the MNIST data into the load_data method of the MNISTTrainer class.
Moved the code for defining the PyTorch model into the define_model method of the MNISTTrainer class.
Implemented the train method in the MNISTTrainer class, which contains the training loop for the model.
Implemented the save_model method in the MNISTTrainer class, which saves the trained model to a file.
Implemented the load_model method in the MNISTTrainer class, which loads the model from a file.
Updated the code in src/api.py to use the load_model method of the MNISTTrainer class for loading the model.
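The save/load round trip and its use from the API could look roughly like the sketch below. It is a minimal illustration, not the PR's implementation. The `build_model` helper, the `predict` function, and the file path are hypothetical. The sketch also assumes the model is persisted as a `state_dict` rather than as a pickled module.

```python
import torch
from torch import nn


def build_model() -> nn.Module:
    # Hypothetical architecture; it must match whatever define_model() builds.
    return nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 10))


class MNISTTrainer:
    """Hypothetical sketch showing only the persistence methods."""

    def __init__(self):
        self.model = build_model()

    def save_model(self, path: str) -> None:
        # Save only the learned parameters (state_dict), not the whole pickled module.
        torch.save(self.model.state_dict(), path)

    def load_model(self, path: str) -> nn.Module:
        # Rebuild the architecture, then restore the saved parameters onto the CPU.
        model = build_model()
        model.load_state_dict(torch.load(path, map_location="cpu"))
        model.eval()  # inference mode for serving (matters for dropout/batch-norm layers)
        self.model = model
        return model


def predict(trainer: MNISTTrainer, image: torch.Tensor) -> int:
    # Roughly how src/api.py might use the loaded model: one 1x28x28 image in, one digit out.
    with torch.no_grad():
        logits = trainer.model(image.unsqueeze(0))
    return int(logits.argmax(dim=1).item())
```

In src/api.py this would amount to creating the trainer once at startup and calling `trainer.load_model("mnist_model.pth")` (path assumed). Request handlers then reuse `trainer.model` instead of rebuilding it per request.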
Please review and merge this PR to complete the requested refactoring of the training loop.
Fixes #6.
🎉 Latest improvements to Sweep:
Sweep can now passively improve your repository! Check out Rules to learn more.
💡 To get Sweep to edit this pull request, you can:
Comment below, and Sweep can edit the entire PR
Comment on a file, Sweep will only modify the commented file
Edit the original issue to get Sweep to recreate the PR from scratch