使用PyTorch和Albumentations进行数据增强与损失函数

数据扩增

Part 1 数据读取与数据扩增

图像读取

常用的图像读取方法:OpenCV-python、Pillow、matplotlib.image、scipy.misc、skimage

  • Pillow只提供最基础的数字图像处理,功能有限,但方便轻巧
  • Scikit-image是基于scipy的一款图像处理包,功能强大
  • OpenCV是一个非常全面的图像处理、计算机视觉库
  • Pillow读入的图片是img类,其他库读进来的图片都是numpy矩阵
  • OpenCV读入的彩色图像通道顺序是BGR,其他图像库读入的彩色图像顺序都是RGB

数据扩增

数据扩增是一种有效的正则化方法,可以缓解模型过拟合,迫使网络学习到更鲁棒、更多样的特征,给模型带来更强的泛化能力。

使用PyTorch和Albumentations进行数据增强与损失函数
常用数据扩增技术分类:

推荐论文阅读:A survey on Image Data Augmentation for Deep Learning-2019
https://link.springer.com/article/10.1186/s40537-019-0197-0#Sec3

基于图像处理的数据扩增

几何变换

旋转、缩放、翻转、裁剪、平移、仿射变换

作用:几何变换可以有效的对抗数据中存在的位置偏差、视角偏差、尺寸偏差,而且易于实现
使用PyTorch和Albumentations进行数据增强与损失函数

灰度和彩色空间变换

亮度调整,对比度、饱和度调整,颜色空间转换,色彩调整,gamma变换

作用:对抗数据中存在的光照、色彩、亮度、对比度偏差

使用PyTorch和Albumentations进行数据增强与损失函数

添加噪声和滤波

  • 添加高斯噪声、椒盐噪声
  • 滤波:模糊、锐化、雾化

作用:应对噪声干扰、恶劣环境、成像异常等特殊情况,帮助学习更泛化的特征
使用PyTorch和Albumentations进行数据增强与损失函数

图像混合(Mixing images)

使用PyTorch和Albumentations进行数据增强与损失函数

随机搽除(Random erasing)

使用PyTorch和Albumentations进行数据增强与损失函数

基于深度学习的数据扩增

  • 基于GAN的数据增强(GAN-based Data Augmentation):使用GAN生成模型来生成更多的数据,可用做解决类别不平衡问题的过采样技术。
  • 神经风格转换(Neural Style Transfer):通过神经网络风格迁移来生成不同风格的数据,防止模型过拟合
  • AutoAugment

使用PyTorch进行数据增强

在PyTorch中,常用的数据增强的函数主要集成在torchvision.transforms

使用PyTorch和Albumentations进行数据增强与损失函数
使用PyTorch和Albumentations进行数据增强与损失函数
使用PyTorch进行数据增强

from PIL import Image
from torchvision import transforms as tfs
import matplotlib.pyplot as plt

im=Image.open('dog.jpg')
im_aug=tfs.Compose([
    tfs.Resize([200,200]),
    tfs.RandomVerticalFlip(),
    tfs.RandomCrop(110),
    tfs.ColorJitter(brightness=0.5,contrast=0.5,hue=0.5),
])

nrows=4
ncols=4
figsize=(8,8)
_,figs=plt.subplots(nrows,ncols,figsize=figsize)
for i in range(nrows):
    for j in range(ncols):
        figs[i][j].imshow(im_aug(im))
        figs[i][j].axes.get_xaxis().set_visible(False)
        figs[i][j].axes.get_yaxis().set_visible(False)
plt.show()

使用PyTorch和Albumentations进行数据增强与损失函数
使用PyTorch和Albumentations进行数据增强与损失函数

使用Albumentations进行数据增强

https://github.com/albumentations-team/albumentations

Part 2 评价与损失函数

IOU

T表示真实前景,P表示预测前景
使用PyTorch和Albumentations进行数据增强与损失函数
使用PyTorch和Albumentations进行数据增强与损失函数

Dice coefficient

使用PyTorch和Albumentations进行数据增强与损失函数
A表示真实前景,B表示预测前景,Dice系数取值范围为[0,1]

用来度量集合相似度的度量函数,通常用于计算两个样本之间的像素之间的相似度。

Dice系数不仅在直观上体现了target与prediction的相似程度,同时其本质上还隐含了精确率和召回率两个重要指标。

Dice Loss

通过Dice系数转变而来,为了能够实现最小化的损失函数,方便模型训练,以1-Dice的形式作为损失函数。
使用PyTorch和Albumentations进行数据增强与损失函数
在一些场合还可以添加上Laplace smoothing减少过拟合(为了解决零概率问题):
使用PyTorch和Albumentations进行数据增强与损失函数

Binary Cross-Entropy

使用PyTorch和Albumentations进行数据增强与损失函数
y:真实值,非1即0;

y ˉ \bar{y} yˉ​:所属此类的概率值,为预测值;

交叉熵损失函数可以用在大多数语义分割场景中,BCE损失函数(Binary Cross-Entropy Loss)是交叉损失函数(Cross-Entropy Loss)的一种特例,BCE Loss只能应用在二分类任务中,对于像素级的分类任务时效果不错。

缺点:当前景像素的数量远小于背景像素的数量时,可能会使得模型严重偏向背景,导致效果不佳。

Balanced Cross-Entropy

使用PyTorch和Albumentations进行数据增强与损失函数
y:真实值,非1即0;

y ˉ \bar{y} yˉ​:所属此类的概率值,为预测值;

设置 β > 1 \beta >1 β>1,减少假阴性;设置 β < 1 \beta <1 β<1,减少假阳性

优点:相比于原始的二元交叉熵Loss,在样本数量不均衡的情况下,可以获得更好的效果。

Focal Loss

使用PyTorch和Albumentations进行数据增强与损失函数
使用PyTorch和Albumentations进行数据增强与损失函数
Focal Loss最初是出现在目标检测领域,主要是为了解决正负样本、难易样本比例失调的问题。

简而言之, α \alpha α解决正负样本不平衡问题, γ \gamma γ解决难易样本不平衡问题。

使用PyTorch和Albumentations进行数据增强与损失函数
易分样本(即置信度高的样本),对模型的提升效果非常小,模型应该主要关注那些难分的样本。

上一篇:使用Jenkins来实现内部的持续集成流程(下)


下一篇:windows 系统命令 大全