paddlepaddle 9 MC Dropout的使用

MC Dropout是指蒙特卡罗Dropout,其可以在不改就网络结构与增加训练的情况下在测试阶段提升模型的性能,本质就是在测试时将dropout一直处于激活阶段。对网络进行多次前向传播,由于dropout每一次激活的神经元都不同,使得每次的结果都会不一样。将多次输出的结果取平均值,可以在一定程度上提升算法的准确性,但是会降低算法的推理速度。

在paddlepaddle 中是无法在测试阶段将dropout处于激活状态的,其根本原因是paddlepaddle中的dropout中的参数列表中无法指定train与eval,其参数列表如下所示:paddle.nn.Dropout(p=0.5axis=Nonemode="upscale_in_train”name=None)

  • p (float): 将输入节点置为0的概率, 即丢弃概率。默认: 0.5。

  • axis (int|list): 指定对输入 Tensor 进行Dropout操作的轴。默认: None。

  • mode (str): 丢弃单元的方式,有两种'upscale_in_train'和'downscale_in_infer',默认: 'upscale_in_train'。计算方法如下:

    1. upscale_in_train, 在训练时增大输出结果。

      • train: out = input * mask / ( 1.0 - p )

      • inference: out = input

    2. downscale_in_infer, 在预测时减小输出结果

      • train: out = input * mask

      • inference: out = input * (1.0 - p)

  • name (str,可选): 操作的名称(可选,默认值为None)。更多信息请参见 Name 

因此,要实现dropout的激活状态只能通过model.train()来使模型中的dropout处于激活状态,但是设置model.train()后,笔者发现模型前向传播过程中的gpu占用是无法被清除,无论batch_size调多小,只要测试的数据一多,就会导致显存不够用。因此,参考模型训练过程的显存清空方式,实现MC_Dropout。

import paddle
#paddle.set_flags({'FLAGS_eager_delete_tensor_gb': 1.0})
paddle.set_flags({'FLAGS_fast_eager_deletion_mode': True })#使用快速垃圾回收策略,实际中没有任何作用
def MC_Dropout(model,data,times=10):#蒙特卡罗Dropout
    model.train()
    #为了反向传播中的清空显存功能,learning_rate为0表示不让模型进行参数更新
    optim = paddle.optimizer.Adam(learning_rate=0.0,parameters=model.parameters())
    loss_fn = paddle.nn.CrossEntropyLoss(soft_label=True)#使用软标签    
    result=[]
    for i in range(times):
        out=model(x_data)
        preds=paddle.nn.functional.softmax(out)
        result.append(preds.numpy())
        #借用反向传播释放内存
        loss = loss_fn(out, out)
        loss.backward()
        optim.step()
        optim.clear_grad()
    result=np.array(result)#shape为t,b,c t:times,b:batch_size,c:class_probability
    result=np.transpose(result,(1,2,0))#shape为b,c,t
    result=result.sum(axis=-1)#shape为b,c
    result=result.argmax(axis=-1)#shape为b,c
    return result
model=paddle.jit.load("model/ep125_loss0.400336_acc0.9306model")
model.train()
Imagetest=ImageClsTestDataset((256,256),"data/data10954/cat_12_test")
BATCH_SIZE=12
# 如果要加载内置数据集,将 custom_dataset 换为 train_dataset 即可
train_loader = paddle.io.DataLoader(Imagetest, batch_size=BATCH_SIZE, shuffle=False)
print('=============train model=============')
count=0
results=[]
name_list=[]
for batch_id, data in enumerate(train_loader()):
    x_data = data[0]
    names = data[1]
    results+=MC_Dropout(model,x_data,times=1).tolist()
    name_list+=names   
    count+=len(names)
print(count)
print(results)

显存清空方式说明: 基于paddle.optimizer和loss函数进行假反向传播(让学习率为0,loss为0),使优化器不会对模型的参数进行实质更新。

数据加载器ImageClsTestDataset的实现:传入图片路径,自动加载到list中,无需生成txt列表

import paddle
from paddle.io import Dataset
from paddle.vision import transforms
from PIL import Image
import numpy as np
import os

class ImageClsTestDataset(Dataset):
    def __init__(self,input_shape,root):
        super(ImageClsTestDataset, self).__init__()
        self.input_shape=input_shape
        #ToTensor将形状为 (H x W x C)的输入数据 PIL.Image 或 numpy.ndarray 转换为 (C x H x W),并进行归一化。如果想保持形状不变,可以将参数 data_format 设置为 'HWC'
        #在paddle模型中数据是CHW的格式
        self.preprocess_image=transforms.Compose([
                    transforms.Resize(input_shape),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),data_format='CHW')
                ])
        self.root=root
        self.lists=os.listdir(root)
        self.length = len(self.lists)

    def __getitem__(self, index):
        name = self.root+'/'+self.lists[index]
        image      = Image.open(name)
        np_img     = np.array(image)
        if len(np_img.shape)==2:#防止数据中存在灰度图
            tmp=np.ones((*np_img.shape,3))
            tmp[:,:,0]=np_img
            tmp[:,:,1]=np_img
            tmp[:,:,2]=np_img
            np_img=tmp 
        np_img=np_img.astype(np.uint8) 
        image      = self.preprocess_image(np_img[:,:,:3])#np_img[:,:,:3]防止数据中存在RGBA的四通道数据
        return image, name

    def __len__(self):
        return self.length

上一篇:PyQt5基础学习-QTimer(时间计时器) 1.QDateTime.currentDateTime(显示当前时间) 2.QTimer().start(设置时间的间隔) 3.QTimer().sto


下一篇:【lzy学习笔记-dive into deep learning】3.1线性回归 3.2 从零开始的实现