模型权重的保存与加载 回调函数的使用

import tensorflow as tf
import numpy as np
#保存权重
model_file = "./model/cifarmodel.h5"
model.save_weights(model_file)
print("已保存模型权重!")

#加载权重
try:
    model.load_weights(model_file)
    print("权重加载成功!")
except:
    print("权重加载失败!")
    
    
    
#通过回调函数保存权重
check_points = "./model/cifar10.{epoch:02d} - {val_loss: .4f}.h5"
callbacks = [ tf.keras.callbacks.ModelCheckpoint(filepath = check_points, 
                                                 save_weights_only = True,
                                                 verbose = 0, save_freq = epoch),
             tf.keras.callbacks.EarlyStopping(monitor = "val_loss", patience = 3),
             # 早停,防止过拟合,监控值为val_loss,如果连续三个周期的值都越来越差,则停止保存check_points
]
#回调函数的应用
model.fit(train_x, train_y, validation_split = 0.3, epoch = 5, batch_size = 100, callbacks = callbacks, verbose = 2)
    

 

模型权重的保存与加载 回调函数的使用

上一篇:WSL2 中 安装卸载 Docker


下一篇:列表、元组