Observing the learning rate during training in Keras

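import numpy as np
import keras  # standalone Keras 2.x; with TensorFlow use `from tensorflow import keras`
from sklearn.metrics import f1_score, precision_score, recall_score

class Metrics(keras.callbacks.Callback):
    """At the end of every epoch: compute macro F1/precision/recall on the
    validation data, write them into `logs`, and print them together with
    the current learning rate."""

    def __init__(self, valid_data):
        super(Metrics, self).__init__()
        # Stored under a custom name: some Keras 2.x versions overwrite a
        # callback's `validation_data` attribute inside fit()
        self.valid_data = valid_data

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        # Predicted class index = argmax over the last (class) axis
        val_predict = np.argmax(self.model.predict(self.valid_data[0]), -1)
        val_targ = self.valid_data[1]

        # Convert one-hot targets to class indices
        if len(val_targ.shape) == 2 and val_targ.shape[1] != 1:
            val_targ = np.argmax(val_targ, -1)

        _val_f1 = f1_score(val_targ, val_predict, average='macro')
        _val_recall = recall_score(val_targ, val_predict, average='macro')
        _val_precision = precision_score(val_targ, val_predict, average='macro')

        # Writing into `logs` lets callbacks that run later in the list
        # (ReduceLROnPlateau, EarlyStopping, ModelCheckpoint) monitor 'val_f1'
        logs['val_f1'] = _val_f1
        logs['val_recall'] = _val_recall
        logs['val_precision'] = _val_precision

        # Current learning rate (`lr` is the Keras 2.x attribute name;
        # recent versions expose it as `optimizer.learning_rate`)
        lr = keras.backend.get_value(self.model.optimizer.lr)
        print(" — val_f1: %f — val_precision: %f — val_recall: %f - lr: %f"
              % (_val_f1, _val_precision, _val_recall, lr))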

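# mode='max' because a higher F1 is better. With the default mode='auto', Keras
# only picks "max" for monitor names containing 'acc', so 'val_f1' would be minimized.
reduce_lr = keras.callbacks.ReduceLROnPlateau(monitor='val_f1', mode='max', factor=0.8,
                                              patience=15, min_delta=0.001, verbose=1)

# early_stop (EarlyStopping) and ck_callback (ModelCheckpoint) are not defined in
# this snippet; like reduce_lr, they need mode='max' to track 'val_f1'.
history = model.fit(X_train_input, y_train_dim4,
                    validation_data=(X_val_input, y_val_dim4),
                    batch_size=128, epochs=150,  # values tried: epochs=60, 150
                    shuffle=False,
                    verbose=0,  # no progress bar; Metrics prints one line per epoch
                    # Metrics goes first: all callbacks share one `logs` dict per
                    # epoch, so 'val_f1' must be written before the others read it
                    callbacks=[Metrics(valid_data=(X_val_input, y_val)),
                               reduce_lr,
                               early_stop,
                               ck_callback],
                    # class_weight='auto' removed: Keras expects a dict
                    # {class_index: weight} (or None), not the string 'auto'
                    )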
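The Metrics callback relies on scikit-learn's `average='macro'`: precision, recall and F1 are computed separately for each class and then averaged without weighting. The pure-Python sketch below mirrors that computation, including the one-hot-to-index conversion; the helper names `argmax_rows` and `macro_scores` are illustrative and not part of the original code, and undefined ratios (such as a class that is never predicted) are simply counted as 0 instead of triggering sklearn's zero-division warning.

```python
def argmax_rows(rows):
    """Index of the largest entry in each row, like np.argmax(x, -1) on a 2-D array."""
    return [max(range(len(r)), key=r.__getitem__) for r in rows]


def macro_scores(y_true, y_pred):
    """Unweighted mean of per-class precision, recall and F1 (like average='macro').

    Classes are all labels seen in y_true or y_pred; an undefined ratio counts as 0.
    """
    labels = sorted(set(y_true) | set(y_pred))
    precisions, recalls, f1s = [], [], []
    for c in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(labels)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


# One-hot targets for a toy 3-class problem
y_val_onehot = [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]]
y_true = argmax_rows(y_val_onehot)   # [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

precision, recall, f1 = macro_scores(y_true, y_pred)
print("macro P=%.4f R=%.4f F1=%.4f" % (precision, recall, f1))  # macro P=0.7222 R=0.6667 F1=0.6556
```

Because macro F1 is the mean of the per-class F1 scores, it is generally not equal to 2PR/(P+R) of the macro precision and recall (here 0.6556 versus about 0.6933), so `val_f1` in the log need not match the harmonic mean of the printed `val_precision` and `val_recall`.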

Results:

1. At the end of each training epoch:
Epoch 00004: val_f1 improved from 0.94381 to 0.95022, saving model to ./checkpoints3/weights.04-0.9502.hdf5
 — val_f1: 0.948708 — val_precision: 0.957801 — val_recall: 0.941037 - lr: 0.001000
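The checkpoint line in this output comes from `ck_callback`, which the post does not define. Judging by the message and the saved file name, it is most likely a `keras.callbacks.ModelCheckpoint` monitoring `'val_f1'` with `mode='max'` and `save_best_only=True`, using a filepath template like the hypothetical one below (an inference from the log, not the author's confirmed settings). Keras 2.x fills the template with `filepath.format(epoch=epoch + 1, **logs)`, so `val_f1` must already be in `logs` when the checkpoint callback runs. The sketch reproduces only that formatting step:

```python
# Hypothetical template, reconstructed from the file name in the log above
CKPT_TEMPLATE = "./checkpoints3/weights.{epoch:02d}-{val_f1:.4f}.hdf5"


def checkpoint_path(template, epoch, logs):
    """Fill a ModelCheckpoint-style template. `epoch` is the 0-based index passed
    to on_epoch_end; the file name uses the 1-based epoch number."""
    return template.format(epoch=epoch + 1, **logs)


# Epoch index 3 is printed as "Epoch 00004"; val_f1 is written by the Metrics callback
path = checkpoint_path(CKPT_TEMPLATE, 3, {"val_f1": 0.95022})
print(path)  # ./checkpoints3/weights.04-0.9502.hdf5
```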

2. When the learning rate is adjusted dynamically, the new learning rate is reported:

Epoch 00149: ReduceLROnPlateau reducing learning rate to 0.00016777217388153076.
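Assuming the initial rate was the 0.001 shown in the first output, this value matches eight reductions by `factor=0.8`: 0.001 * 0.8**8 = 0.00016777216. The extra trailing digits in the log are likely float32 rounding of the stored learning rate. The simplified pure-Python model below (the class name `PlateauLR` is illustrative; cooldown and `min_lr` are omitted) shows the plateau rule that ReduceLROnPlateau applies for a metric where higher is better: an epoch counts as an improvement only if the metric beats the best value so far by more than `min_delta`, and after `patience` epochs without improvement the rate is multiplied by `factor`.

```python
class PlateauLR:
    """Simplified plateau rule for a metric where higher is better ('max' mode).

    Illustrative only: cooldown and min_lr are omitted.
    """

    def __init__(self, lr, factor=0.8, patience=15, min_delta=0.001):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("-inf")
        self.wait = 0          # epochs since the last improvement
        self.reductions = 0

    def on_epoch_end(self, metric):
        if metric > self.best + self.min_delta:   # improvement in 'max' mode
            self.best = metric
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:        # plateau: shrink the learning rate
                self.lr *= self.factor
                self.reductions += 1
                self.wait = 0
        return self.lr


sched = PlateauLR(lr=0.001)
sched.on_epoch_end(0.95)             # first epoch sets the best value
for _ in range(15):                  # 15 epochs that gain less than min_delta
    sched.on_epoch_end(0.9505)
print(sched.reductions, sched.lr)    # one reduction, lr is about 0.0008

# Eight reductions by a factor of 0.8, starting from 0.001
print(0.001 * 0.8 ** 8)              # about 0.00016777216
```

Comparing in the opposite direction (the "min" behaviour that `mode='auto'` selects for a name like `val_f1`) would treat every drop in F1 as an improvement, which is why `mode='max'` should be set explicitly for this monitor.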

References:

https://blog.csdn.net/qq_43258953/article/details/103356187

https://blog.csdn.net/qq_42699580/article/details/105365782

https://www.cnpython.com/qa/218335
