TensorFlow自定义评估指标

TensorFlow内置常用指标:

  • AUC()
  • Precision()
  • Recall()
  • 等等

有些时候我们的指标不止这些,需要根据我们自己特定的任务指定自己的评估指标,这时就需要自定义Metric,需要子类化Metric,也就是继承keras.metrics.Metric,然后实现它的方法:

  • __init__:这个方法是用来初始化一些变量的
  • update_state:参数有真实值、预测值,采样权重,我们需要在这个方法内进行更新状态变量
  • result:使用状态变量计算最终的评估结果
  • reset_states:重新初始化状态变量

下方实现的评估指标是计算有多少个正类被评估正确,就是预测对了多少样本

完整代码:

"""
 * Created with PyCharm
 * 作者: 阿光
 * 日期: 2022/1/2
 * 时间: 18:32
 * 描述:
"""
import tensorflow as tf
import tensorflow.keras.datasets.mnist
from keras import Input, Model
from keras.layers import Dense
from tensorflow import keras

(train_images, train_labels), (val_images, val_labels) = tensorflow.keras.datasets.mnist.load_data()

train_images, val_images = train_images / 255.0, val_images / 255.0

train_images = train_images.reshape(60000, 784)
val_images = val_images.reshape(10000, 784)


class CategoricalTruePositives(keras.metrics.Metric):
    def __init__(self, name="categorical_true_positives"):
        super(CategoricalTruePositives, self).__init__(name=name)
        self.true_positives = self.add_weight(name='ctp', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.reshape(tf.argmax(y_pred, axis=1), shape=(-1, 1))
        values = tf.cast(y_true, 'int32') == tf.cast(y_pred, 'int32')
        values = tf.cast(values, 'float32')
        self.true_positives.assign_add(tf.reduce_sum(values))

    def result(self):
        return self.true_positives

    def reset_states(self):
        self.true_positives.assign(0.0)


def get_model():
    inputs = Input(shape=(784,))
    outputs = Dense(10, activation='softmax')(inputs)
    model = Model(inputs, outputs)
    model.compile(
        optimizer=keras.optimizers.RMSprop(learning_rate=1e-3),
        loss=keras.losses.SparseCategoricalCrossentropy(),
        metrics=[CategoricalTruePositives()]
    )
    return model


model = get_model()

model.fit(
    train_images,
    train_labels,
    epochs=5,
    batch_size=32,
    validation_data=(val_images, val_labels)
)

上一篇:对batch求算loss时loss是张量形式或是标量有什么不同?


下一篇:MNIST手写数字识别案例TensorFlow 2.0 实践