taoshengshi / specification

Specification of GitData Labs' product portfolio.
https://gitdata.ai
MIT License

bs: Machine Learning Checkpointing #1

Status: Open · opened by taoshengshi 8 months ago

taoshengshi commented 8 months ago

Below is an overview of machine learning checkpointing, followed by the general steps for checkpointing a model.

Checkpointing in machine learning is the technique of saving intermediate models during training so that training can resume from the most recent saved point after a system failure or interruption. It involves periodically saving a neural network's (or other machine learning model's) weights, biases, and other parameters during training, so that the model can be restored to a prior state if training is halted or fails.
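A minimal, framework-free sketch of the save/restore cycle described above, assuming the whole training state (parameters, optimizer state, current epoch) fits in a JSON-serializable dict. `save_checkpoint` and `load_checkpoint` are hypothetical helpers, not part of any library; in PyTorch you would typically save `model.state_dict()` and `optimizer.state_dict()` with `torch.save()` instead of writing JSON.

```python
import json
import os
import tempfile


def save_checkpoint(path, state):
    """Atomically write a JSON-serializable training state to `path`."""
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    # Write to a temporary file first, then rename it over the target, so an
    # interruption mid-write never leaves a truncated checkpoint behind.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(state, f)
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise


def load_checkpoint(path):
    """Return the saved state, or None if no checkpoint exists yet."""
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)


# Example state: parameters, optimizer state, and training progress.
state = {
    "epoch": 3,
    "weights": [0.5, -1.25],
    "bias": 0.1,
    "optimizer": {"lr": 0.01, "momentum": [0.0, 0.02]},
}
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "checkpoints", "latest.json")
    save_checkpoint(ckpt_path, state)
    restored = load_checkpoint(ckpt_path)
    missing = load_checkpoint(os.path.join(tmp, "checkpoints", "none.json"))

print(restored == state)  # True
print(missing)  # None
```

The temporary-file-then-rename pattern matters because a crash during the save itself should not destroy the previous good checkpoint.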

Besides allowing training to restart after a failure or interruption, checkpointing is useful for monitoring a model's progress during training and spotting potential issues early. Saving the model at regular intervals lets you track its performance over time and find trends or anomalies that may need attention.
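To illustrate the monitoring use, here is a small sketch that reviews the validation loss recorded alongside each saved checkpoint. The helper `review_checkpoints`, the `(epoch, val_loss)` history format, and the `patience` threshold are all hypothetical. Flagging training as "stalled" after several checkpoints without improvement is just one common heuristic for spotting overfitting, and treating a non-finite loss as divergence is another.

```python
import math


def review_checkpoints(history, patience=3):
    """Summarize validation losses recorded at each saved checkpoint.

    history: list of (epoch, val_loss) pairs, in the order checkpoints were saved.
    Returns the best checkpoint plus simple warning flags.
    """
    # Non-finite losses (NaN or inf) usually mean training diverged.
    bad = [epoch for epoch, loss in history if not math.isfinite(loss)]
    finite = [(epoch, loss) for epoch, loss in history if math.isfinite(loss)]
    if not finite:
        return {"best_epoch": None, "best_loss": None,
                "diverged_at": bad, "stalled": False}
    best_epoch, best_loss = min(finite, key=lambda item: item[1])
    # Count checkpoints saved after the best one without improving on it.
    since_best = sum(1 for epoch, _ in finite if epoch > best_epoch)
    return {
        "best_epoch": best_epoch,
        "best_loss": best_loss,
        "diverged_at": bad,
        "stalled": since_best >= patience,
    }


history = [(1, 0.92), (2, 0.71), (3, 0.64), (4, 0.66), (5, 0.69), (6, 0.73)]
report = review_checkpoints(history)
print(report)
# {'best_epoch': 3, 'best_loss': 0.64, 'diverged_at': [], 'stalled': True}

diverged = review_checkpoints([(1, 0.9), (2, float("nan"))])
print(diverged["diverged_at"])  # [2]
```

Because every checkpoint is kept on disk, the run flagged here could be rolled back to the epoch-3 checkpoint rather than the last one saved.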

  1. Design the model architecture: Create your own deep learning model architecture or use a pre-trained model.
  2. Optimizer and loss function: Choose the optimizer and loss function to be used during training.
  3. Checkpointing directory: Set the directory where model checkpoints will be saved.
  4. Checkpointing callback: In TensorFlow and Keras, create a ‘ModelCheckpoint’ callback object, which is invoked during training to save model checkpoints. PyTorch has no built-in checkpointing callback; instead, call ‘torch.save()’ from your training loop to save checkpoints.
  5. Train the model: Use the ‘fit()’ method in TensorFlow or Keras (passing the checkpointing callback), or run your own training loop in PyTorch. Checkpoints are saved at predefined intervals throughout training.
  6. Load the checkpoints: To restart training from a prior checkpoint, use the ‘load_weights()’ method in TensorFlow and Keras. In PyTorch, load a saved checkpoint with ‘torch.load()’ and restore it into the model and optimizer with ‘load_state_dict()’.
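The steps above can be sketched end to end without a framework. This is a minimal sketch under stated assumptions: a toy linear model trained by gradient descent stands in for a real network, checkpoints are JSON files named `epoch_NNNN.json` in a checkpoint directory, and `train`, `train_epoch`, and `latest_checkpoint` are hypothetical helpers. In Keras the saving would instead be done by a `ModelCheckpoint` callback passed to `fit()`, and in PyTorch by `torch.save()` calls inside the training loop.

```python
import glob
import json
import os
import tempfile

# Toy data for a 1-D linear model y = 2x + 1 (illustrative only).
XS = [0.0, 1.0, 2.0, 3.0, 4.0]
YS = [2 * x + 1 for x in XS]


def train_epoch(w, b, lr=0.02):
    """One step of full-batch gradient descent on mean squared error."""
    n = len(XS)
    grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(XS, YS)) / n
    grad_b = sum(2 * (w * x + b - y) for x, y in zip(XS, YS)) / n
    return w - lr * grad_w, b - lr * grad_b


def latest_checkpoint(ckpt_dir):
    """Return the path of the highest-epoch checkpoint, or None if there is none."""
    paths = sorted(glob.glob(os.path.join(ckpt_dir, "epoch_*.json")))
    return paths[-1] if paths else None


def train(ckpt_dir, total_epochs, save_every=5, stop_after=None):
    """Train, saving a checkpoint every `save_every` epochs and resuming if possible.

    `stop_after` simulates an interruption (e.g. a crash) after that epoch.
    """
    os.makedirs(ckpt_dir, exist_ok=True)  # step 3: checkpoint directory
    w, b, start = 0.0, 0.0, 1
    path = latest_checkpoint(ckpt_dir)
    if path is not None:  # step 6: resume from the most recent checkpoint
        with open(path) as f:
            state = json.load(f)
        w, b, start = state["w"], state["b"], state["epoch"] + 1
    for epoch in range(start, total_epochs + 1):  # step 5: training loop
        w, b = train_epoch(w, b)
        if epoch % save_every == 0:  # step 4: save at a fixed interval
            # Zero-padded names keep lexicographic order equal to epoch order.
            out = os.path.join(ckpt_dir, f"epoch_{epoch:04d}.json")
            with open(out, "w") as f:
                json.dump({"epoch": epoch, "w": w, "b": b}, f)
        if stop_after is not None and epoch == stop_after:
            return None  # simulated failure: unsaved progress is lost
    return w, b


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = os.path.join(tmp, "checkpoints")
    train(ckpt_dir, total_epochs=20, stop_after=12)  # "crashes" after epoch 12
    latest_name = os.path.basename(latest_checkpoint(ckpt_dir))
    print(latest_name)  # epoch_0010.json
    resumed = train(ckpt_dir, total_epochs=20)  # resumes at epoch 11

with tempfile.TemporaryDirectory() as tmp:
    uninterrupted = train(os.path.join(tmp, "checkpoints"), total_epochs=20)

print(resumed == uninterrupted)  # True
```

Because full-batch gradient descent is deterministic and JSON round-trips Python floats exactly, the resumed run ends with the same parameters as an uninterrupted one; only epochs 11 and 12 are recomputed. The `save_every` interval trades checkpointing overhead against how much work a failure can cost.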

Checkpointing deep learning models throughout training is recommended: it saves time and compute by avoiding restarts from scratch, and it lets you keep the best-performing version of the model rather than only the last one.

Reference: Machine Learning Checkpointing (Deepchecks glossary), https://deepchecks.com/glossary/machine-learning-checkpointing/

taoshengshi commented 8 months ago

This feature is now supported; see https://github.com/GitDataAI/jiaozifs/issues/162