import tensorflow as tf import numpy as np #保存权重 model_file = "./model/cifarmodel.h5" model.save_weights(model_file) print("已保存模型权重!") #加载权重 try: model.load_weights(model_file) print("权重加载成功!") except: print("权重加载失败!") #通过回调函数保存权重 check_points = "./model/cifar10.{epoch:02d} - {val_loss: .4f}.h5" callbacks = [ tf.keras.callbacks.ModelCheckpoint(filepath = check_points, save_weights_only = True, verbose = 0, save_freq = 'epoch'), tf.keras.callbacks.EarlyStopping(monitor = "val_loss", patience = 3), # 早停,防止过拟合,监控值为val_loss,如果连续三个周期的值都越来越差,则停止保存check_points ] #回调函数的应用 model.fit(train_x, train_y, validation_split = 0.3, epoch = 5, batch_size = 100, callbacks = callbacks, verbose = 2)