如何找到优化深度学习模型的方向
问题
当训练的模型的预测准确率(accuracy)不高,如何寻找模型的下一步调优方向?
通过对比训练集与开发(测试)集的准确率,可以确定模型的偏差(Bias)与方差(Variance)问题,给下一步模型的优化提供方向指导。
训练/开发/测试集
模型训练前通常将数据集划分为 训练集(Training set)、开发集(Dev set)与测试集(Test set)。训练集、开发集与测试集在数据集中的比例通常为6:2:2,而测试集与开发集的最大数量可以不用超过10000条:比如整个数据集有1000条,那么训练集、开发集与测试集可以分别为600条:200条:200条;如果整个数据集有10 000 000条数据,那么训练集、开发集与测试集可以分别为:9 980 000条:10 000条:10 000条。
模型使用训练集中的数据进行参数的训练,训练完成后,对开发集与测试集中的数据做验证。通过模型对训练集与开发集预测的准确率做统计对比,可以分析当前模型存在的欠拟合(Underfitting)/过拟合(Overfitting)问题,即高偏差(Bias)或者高方差(Variance)。
偏差(Bias)与方差(Variance)
假设通过房屋的面积预测价格,对于训练集,下图是比较合适的拟合曲线:因为随着面积的增加,房价会增加,但是增速会变缓。
偏差(Bias)
下图的预测与实际的情况存在较高的偏差(Bias),或者说下图的预测欠拟合(Underfitting),因为模型的预测与真实的房价存在较大的差距。模型的高偏差或者欠拟合表达的是相同的概念。
方差(Variance)
下图的预测与实际情况存在较高的方差(Variance),或者说下图的预测是过拟合(Overfitting),原因是虽然下图的曲线很好的模拟和训练集中的数据,但反应到真实的房价上看,面积越大反而价格越便宜,这不太符合实际情况(要是真的就好了:()。
问题定位
实际的模型比较复杂,无法通过上图中简单的图标定位模型的问题。因此,通过对比训练集与开发集的准确率(accuracy),也可以分析模型存在的问题是高偏差或者是高方差:
- 模型对于训练集和开发集的预测准确率都很高,恭喜模型很完美(下图右上角)。
- 模型对于训练集和开发集的准确率很低,属于高偏差问题:即模型无法正常拟合训练集中的数据,或者说模型对于训练集中的数据是**欠拟合(Underfitting)**的(左下角)。
- 模型对于训练集中的数据预测的准确率很高,但是对于开发集中的数据预测准确率较低:表明模型可以很好的预测它“看过”的数据,但对于它“没看过”的数据预测的准确率不高。即模型对于其训练集存在**过拟合(Overfitting)**问题(左上角)。
- 我想应该不会存在模型对训练集预测不高,对开发集预测很好的实际例子吧(右下角)? 如果存在,我想会不会是我的开发集出了问题?
高偏差(Bias)的优化
对于高偏差(Bias)或者欠拟合(Underfitting)问题,可以考虑如下优化方向:
- 更大/更深的网络结构,比如增加各层的神经单元,以及增大各神经单元的训练参数等。
- 训练更长的时间,尝试找到更优解。
- 更换网络架构,比如Yolo不行的话,试试看Mask RNN。
高方差(Variance)的优化
对于高方差(Variance)问题或者过拟合(Overfitting)问题,可以考虑如下优化方向:
- 获得更多的训练数据。
- 使用正则化(Reguliaztion)防止过拟合问题,如常用的L2 Regulization与Dropout等。
- 更换网络架构。
- 使用数据增强(Data Argumentation)。
- 早停法(early stop)。