Open CHH3213 opened 2 years ago
https://blog.csdn.net/LuohenYJ/article/details/81096886
代码示例
from __future__ import absolute_import, division, print_function import os import tensorflow as tf from tensorflow import keras tf.__version__ (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() train_labels = train_labels[:1000] test_labels = test_labels[:1000] train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0 test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 # 模型创建模型 def create_model(): model = tf.keras.models.Sequential([ keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)), keras.layers.Dropout(0.2), keras.layers.Dense(10, activation=tf.nn.softmax) ]) model.compile(optimizer=tf.keras.optimizers.Adam(), loss=tf.keras.losses.sparse_categorical_crossentropy, metrics=['accuracy']) return model #创建模型 model = create_model() model.summary() checkpoint_path = "training_1/cp.ckpt" checkpoint_dir = os.path.dirname(checkpoint_path) #创建回调函数 cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True, #只保存权重 verbose=1) model = create_model() model.fit(train_images, train_labels, epochs = 10, validation_data = (test_images,test_labels), callbacks = [cp_callback]) #保存模型 #对全新没有训练的模型进行预测 model = create_model() loss, acc = model.evaluate(test_images, test_labels) print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4% #载入权重参数后的模型 model.load_weights(checkpoint_path) loss,acc = model.evaluate(test_images, test_labels) print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2 # 保存权重 model.save_weights('./checkpoints/my_checkpoint') #恢复模型 model = create_model() model.load_weights('./checkpoints/my_checkpoint') loss,acc = model.evaluate(test_images, test_labels) print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00% #将整个模型保存为HDF5文件 # model = create_model() # model.fit(train_images, train_labels, epochs=5) # model.save('my_model.h5') # #载入一个相同的模型 # new_model = keras.models.load_model('my_model.h5') # new_model.summary() # loss, acc = new_model.evaluate(test_images, test_labels) # print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%
保存权重后生成的目录:
https://blog.csdn.net/LuohenYJ/article/details/81096886
代码示例
保存权重后生成的目录: