Keras loss functions explained

All of the information below comes from the official Keras documentation.

------------------------------------------------------------------------------------------------------------

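Using loss functions

A loss function (also called an objective function or optimization score function) is one of the two parameters required to compile a model: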

# Option 1: pass the loss function by name
model.compile(loss='mean_squared_error', optimizer='sgd')

# Option 2: pass the loss function object itself
from keras import losses

model.compile(loss=losses.mean_squared_error, optimizer='sgd')

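You can pass either the name of an existing loss function or a TensorFlow/Theano symbolic function. The symbolic function returns a scalar for each data point and takes the following two arguments:

  • y_true: the ground-truth labels. A TensorFlow/Theano tensor.
  • y_pred: the predictions. A TensorFlow/Theano tensor with the same shape as y_true.

The actual optimization objective is the mean of the output array across all data points.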
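To make the two-level contract concrete (one scalar per data point, then a mean over all data points), here is a minimal sketch that uses NumPy as a stand-in for the TensorFlow/Theano backend. The function name my_mean_squared_error is hypothetical; it only illustrates the shape of a custom loss, not how Keras evaluates it internally.

```python
import numpy as np

def my_mean_squared_error(y_true, y_pred):
    # Hypothetical custom loss with the Keras (y_true, y_pred) signature.
    # It reduces only the last axis, so it returns one scalar per data point.
    return np.mean(np.square(y_pred - y_true), axis=-1)

y_true = np.array([[0.0, 1.0], [1.0, 1.0]])
y_pred = np.array([[0.0, 0.0], [1.0, 3.0]])

per_sample = my_mean_squared_error(y_true, y_pred)  # shape (2,): [0.5, 2.0]
objective = per_sample.mean()                       # 1.25, the value being minimized
```

In a real model the body would use backend ops such as K.mean and K.square (as in the source listing at the end of this article) so that it works on tensors, and the function itself would be passed as model.compile(loss=my_mean_squared_error, optimizer='sgd').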

Available loss functions

mean_squared_error

mean_squared_error(y_true, y_pred)

mean_absolute_error

mean_absolute_error(y_true, y_pred)

mean_absolute_percentage_error

mean_absolute_percentage_error(y_true, y_pred)

mean_squared_logarithmic_error

mean_squared_logarithmic_error(y_true, y_pred)
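The four regression losses above can be sketched in NumPy by transcribing the Keras source shown at the end of this article. NumPy replaces the backend K here, and EPSILON = 1e-7 is an assumption standing in for K.epsilon() (the Keras default); each function returns one value per sample.

```python
import numpy as np

EPSILON = 1e-7  # assumed value of K.epsilon()

def mean_squared_error(y_true, y_pred):
    return np.mean(np.square(y_pred - y_true), axis=-1)

def mean_absolute_error(y_true, y_pred):
    return np.mean(np.abs(y_pred - y_true), axis=-1)

def mean_absolute_percentage_error(y_true, y_pred):
    # |y_true| is clipped away from 0 so the division never blows up
    diff = np.abs((y_true - y_pred) / np.clip(np.abs(y_true), EPSILON, None))
    return 100.0 * np.mean(diff, axis=-1)

def mean_squared_logarithmic_error(y_true, y_pred):
    # values are clipped at EPSILON before log(x + 1), as in the source
    first_log = np.log(np.clip(y_pred, EPSILON, None) + 1.0)
    second_log = np.log(np.clip(y_true, EPSILON, None) + 1.0)
    return np.mean(np.square(first_log - second_log), axis=-1)

y_true = np.array([[1.0, 2.0, 4.0]])
y_pred = np.array([[2.0, 2.0, 2.0]])
mse = mean_squared_error(y_true, y_pred)               # [5/3]
mae = mean_absolute_error(y_true, y_pred)              # [1.0]
mape = mean_absolute_percentage_error(y_true, y_pred)  # [50.0]
```

Note how MAPE weights the same absolute error of 2 differently from MAE: relative to a target of 4 it counts as 50%, which is why MAPE is sensitive to targets near zero.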

squared_hinge

squared_hinge(y_true, y_pred)

hinge

hinge(y_true, y_pred)

categorical_hinge

categorical_hinge(y_true, y_pred)
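A NumPy transcription of the three hinge losses from the source below (NumPy standing in for K is an assumption). Because the formula is 1 - y_true * y_pred, hinge and squared_hinge are designed for labels of -1/+1: a prediction with the right sign and a margin of at least 1 costs nothing. categorical_hinge instead expects one-hot y_true and compares the true-class score with the best competing score.

```python
import numpy as np

def hinge(y_true, y_pred):
    return np.mean(np.maximum(1.0 - y_true * y_pred, 0.0), axis=-1)

def squared_hinge(y_true, y_pred):
    return np.mean(np.square(np.maximum(1.0 - y_true * y_pred, 0.0)), axis=-1)

def categorical_hinge(y_true, y_pred):
    pos = np.sum(y_true * y_pred, axis=-1)          # score of the true class
    neg = np.max((1.0 - y_true) * y_pred, axis=-1)  # highest score among the other classes
    return np.maximum(0.0, neg - pos + 1.0)

# binary case: labels in {-1, +1}
y_true = np.array([[1.0, -1.0]])
y_pred = np.array([[0.5, -2.0]])  # first margin too small, second comfortably correct
h = hinge(y_true, y_pred)          # [0.25]
sh = squared_hinge(y_true, y_pred) # [0.125]

# categorical case: one-hot target, true class scored 0.5 vs. best rival 0.3
ch = categorical_hinge(np.array([[0.0, 1.0, 0.0]]), np.array([[0.2, 0.5, 0.3]]))  # [0.8]
```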

logcosh

logcosh(y_true, y_pred)

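The logarithm of the hyperbolic cosine of the prediction error.

For small x, log(cosh(x)) is approximately equal to (x ** 2) / 2; for large x, it is approximately abs(x) - log(2). This means 'logcosh' works mostly like mean squared error, but is not as strongly affected by the occasional wildly incorrect prediction.

Arguments

  • y_true: tensor of true targets.
  • y_pred: tensor of predicted targets.

Returns

A tensor with one scalar loss entry per sample.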


categorical_crossentropy

categorical_crossentropy(y_true, y_pred)

sparse_categorical_crossentropy

sparse_categorical_crossentropy(y_true, y_pred)

binary_crossentropy

binary_crossentropy(y_true, y_pred)
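The source below delegates the three crossentropy losses to backend functions (K.categorical_crossentropy and friends) that are not listed, so this NumPy sketch writes out the standard formulas instead. It assumes y_pred already holds probabilities rather than logits, and the EPSILON clipping is an assumption modeled on the clipping used elsewhere in the source, not a copy of the backend code. It also shows that sparse_categorical_crossentropy on integer labels equals categorical_crossentropy on the matching one-hot labels.

```python
import numpy as np

EPSILON = 1e-7  # assumed value of K.epsilon()

def categorical_crossentropy(y_true, y_pred):
    # y_true: one-hot rows; y_pred: class probabilities
    y_pred = np.clip(y_pred, EPSILON, 1.0 - EPSILON)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

def sparse_categorical_crossentropy(y_true, y_pred):
    # y_true: integer class indices, one per row of a 2-D y_pred
    y_pred = np.clip(y_pred, EPSILON, 1.0 - EPSILON)
    picked = y_pred[np.arange(y_pred.shape[0]), y_true.astype(int)]
    return -np.log(picked)

def binary_crossentropy(y_true, y_pred):
    y_pred = np.clip(y_pred, EPSILON, 1.0 - EPSILON)
    bce = -(y_true * np.log(y_pred) + (1.0 - y_true) * np.log(1.0 - y_pred))
    return np.mean(bce, axis=-1)  # averaged over the last axis, as in the source

probs = np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]])
int_labels = np.array([0, 2])
one_hot = np.eye(3)[int_labels]

cce = categorical_crossentropy(one_hot, probs)              # [-log 0.7, -log 0.8]
scce = sparse_categorical_crossentropy(int_labels, probs)   # same values
bce = binary_crossentropy(np.array([[1.0, 0.0]]), np.array([[0.9, 0.2]]))
```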

kullback_leibler_divergence

kullback_leibler_divergence(y_true, y_pred)
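A NumPy transcription of kullback_leibler_divergence from the source below (NumPy in place of K and EPSILON = 1e-7 for K.epsilon() are assumptions). Both inputs are clipped into [EPSILON, 1] before the log, so they are expected to be probability distributions along the last axis. The example also shows that the divergence is asymmetric.

```python
import numpy as np

EPSILON = 1e-7  # assumed value of K.epsilon()

def kullback_leibler_divergence(y_true, y_pred):
    y_true = np.clip(y_true, EPSILON, 1)
    y_pred = np.clip(y_pred, EPSILON, 1)
    return np.sum(y_true * np.log(y_true / y_pred), axis=-1)

p = np.array([[0.5, 0.5]])
q = np.array([[0.9, 0.1]])
kl_pq = kullback_leibler_divergence(p, q)  # KL(p || q)
kl_qp = kullback_leibler_divergence(q, p)  # KL(q || p), a different value
kl_pp = kullback_leibler_divergence(p, p)  # 0 for identical distributions
```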

poisson

poisson(y_true, y_pred)
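The poisson loss in the source below, y_pred - y_true * log(y_pred), is the negative log-likelihood of a Poisson distribution with rate y_pred, up to the constant log(y_true!) that does not depend on the prediction. A NumPy sketch (NumPy for K and EPSILON = 1e-7 are assumptions) showing that, for a fixed target count, the loss is smallest when the predicted rate equals that count:

```python
import numpy as np

EPSILON = 1e-7  # assumed value of K.epsilon()

def poisson(y_true, y_pred):
    # y_pred is a predicted (non-negative) rate; EPSILON keeps log() finite at 0
    return np.mean(y_pred - y_true * np.log(y_pred + EPSILON), axis=-1)

rates = np.linspace(0.5, 5.0, 10)          # candidate predictions 0.5, 1.0, ..., 5.0
targets = np.full((10, 1), 2.0)            # the true count is 2 for every candidate
losses = poisson(targets, rates[:, None])  # one loss per candidate rate
best_rate = rates[np.argmin(losses)]       # 2.0
```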

cosine_proximity

cosine_proximity(y_true, y_pred)
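cosine_proximity in the source below returns the negative cosine similarity, so minimizing it pushes predictions toward the direction of the target: -1 for the same direction, 0 for orthogonal, 1 for opposite. A NumPy sketch; the l2_normalize helper stands in for K.l2_normalize, and its 1e-12 floor on the squared norm is an assumption modeled on TensorFlow's default epsilon.

```python
import numpy as np

def l2_normalize(x, axis=-1):
    # divide by the L2 norm along `axis`; the 1e-12 floor avoids dividing by zero
    sq = np.sum(np.square(x), axis=axis, keepdims=True)
    return x / np.sqrt(np.maximum(sq, 1e-12))

def cosine_proximity(y_true, y_pred):
    y_true = l2_normalize(y_true, axis=-1)
    y_pred = l2_normalize(y_pred, axis=-1)
    return -np.sum(y_true * y_pred, axis=-1)  # negative cosine similarity

y_true = np.array([[1.0, 0.0], [1.0, 0.0], [1.0, 0.0]])
y_pred = np.array([[3.0, 0.0], [0.0, 2.0], [-1.0, 0.0]])
scores = cosine_proximity(y_true, y_pred)  # same direction, orthogonal, opposite
```

Only direction matters: scaling a prediction (3.0 instead of 1.0 in the first row) does not change its score.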

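Note: when using the categorical_crossentropy loss, your targets should be in categorical format (i.e., if you have 10 classes, the target for each sample should be a 10-dimensional vector that is all zeros except for a 1 at the index corresponding to the sample's class). To convert integer targets into categorical targets, you can use the Keras utility function to_categorical: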

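from keras.utils import to_categorical

# int_labels: array of integer class indices, e.g. [0, 2, 1]
# num_classes=None infers the number of classes from the largest label
categorical_labels = to_categorical(int_labels, num_classes=None)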
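To see what the conversion produces, here is a NumPy sketch of the same one-hot encoding. The helper one_hot is hypothetical and only mimics to_categorical for a 1-D array of labels; with Keras installed, to_categorical itself is the tool to use. Alternatively, integer labels can be kept as they are and trained with sparse_categorical_crossentropy.

```python
import numpy as np

def one_hot(int_labels, num_classes=None):
    # hypothetical stand-in for to_categorical on a 1-D array of class indices
    int_labels = np.asarray(int_labels, dtype=int)
    if num_classes is None:
        num_classes = int(int_labels.max()) + 1  # infer from the largest label
    return np.eye(num_classes, dtype="float32")[int_labels]

labels = one_hot([0, 2, 1])              # rows [1,0,0], [0,0,1], [0,1,0]
padded = one_hot([1], num_classes=4)     # [[0, 1, 0, 0]]
```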

If this is still unclear, take a look at the source code below:
"""Built-in loss functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
from . import backend as K
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object


def mean_squared_error(y_true, y_pred):
    return K.mean(K.square(y_pred - y_true), axis=-1)


def mean_absolute_error(y_true, y_pred):
    return K.mean(K.abs(y_pred - y_true), axis=-1)


def mean_absolute_percentage_error(y_true, y_pred):
    diff = K.abs((y_true - y_pred) / K.clip(K.abs(y_true),
                                            K.epsilon(),
                                            None))
    return 100. * K.mean(diff, axis=-1)


def mean_squared_logarithmic_error(y_true, y_pred):
    first_log = K.log(K.clip(y_pred, K.epsilon(), None) + 1.)
    second_log = K.log(K.clip(y_true, K.epsilon(), None) + 1.)
    return K.mean(K.square(first_log - second_log), axis=-1)


def squared_hinge(y_true, y_pred):
    return K.mean(K.square(K.maximum(1. - y_true * y_pred, 0.)), axis=-1)


def hinge(y_true, y_pred):
    return K.mean(K.maximum(1. - y_true * y_pred, 0.), axis=-1)


def categorical_hinge(y_true, y_pred):
    pos = K.sum(y_true * y_pred, axis=-1)
    neg = K.max((1. - y_true) * y_pred, axis=-1)
    return K.maximum(0., neg - pos + 1.)


def logcosh(y_true, y_pred):
    """Logarithm of the hyperbolic cosine of the prediction error.
    `log(cosh(x))` is approximately equal to `(x ** 2) / 2` for small `x` and
    to `abs(x) - log(2)` for large `x`. This means that 'logcosh' works mostly
    like the mean squared error, but will not be so strongly affected by the
    occasional wildly incorrect prediction.
    # Arguments
        y_true: tensor of true targets.
        y_pred: tensor of predicted targets.
    # Returns
        Tensor with one scalar loss entry per sample.
    """
    def _logcosh(x):
        return x + K.softplus(-2. * x) - K.log(2.)
    return K.mean(_logcosh(y_pred - y_true), axis=-1)


def categorical_crossentropy(y_true, y_pred):
    return K.categorical_crossentropy(y_true, y_pred)


def sparse_categorical_crossentropy(y_true, y_pred):
    return K.sparse_categorical_crossentropy(y_true, y_pred)


def binary_crossentropy(y_true, y_pred):
    return K.mean(K.binary_crossentropy(y_true, y_pred), axis=-1)


def kullback_leibler_divergence(y_true, y_pred):
    y_true = K.clip(y_true, K.epsilon(), 1)
    y_pred = K.clip(y_pred, K.epsilon(), 1)
    return K.sum(y_true * K.log(y_true / y_pred), axis=-1)


def poisson(y_true, y_pred):
    return K.mean(y_pred - y_true * K.log(y_pred + K.epsilon()), axis=-1)


def cosine_proximity(y_true, y_pred):
    y_true = K.l2_normalize(y_true, axis=-1)
    y_pred = K.l2_normalize(y_pred, axis=-1)
    return -K.sum(y_true * y_pred, axis=-1)


# Aliases.

mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
kld = KLD = kullback_leibler_divergence
cosine = cosine_proximity


def serialize(loss):
    return serialize_keras_object(loss)


def deserialize(name, custom_objects=None):
    return deserialize_keras_object(name,
                                    module_objects=globals(),
                                    custom_objects=custom_objects,
                                    printable_module_name='loss function')


def get(identifier):
    """Get the `identifier` loss function.
    # Arguments
        identifier: None or str, name of the function.
    # Returns
        The loss function or None if `identifier` is None.
    # Raises
        ValueError if unknown identifier.
    """
    if identifier is None:
        return None
    if isinstance(identifier, six.string_types):
        identifier = str(identifier)
        return deserialize(identifier)
    if isinstance(identifier, dict):
        return deserialize(identifier)
    elif callable(identifier):
        return identifier
    else:
        raise ValueError('Could not interpret '
                         'loss function identifier:', identifier)