fine-tuning of VGG

一、 fine-tuning

由于数据集的限制,我们可以使用预训练的模型,来重新fine-tuning(微调)。

使用卷积网络作为特征提取器,冻结卷积操作层,这是因为卷积层提取的特征对于许多任务都有用处,使用新的数据集训练新定义的全连接层。

何时以及如何Fine-tune

决定如何使用迁移学习的因素有很多,这是最重要的只有两个:新数据集的大小、以及新数据和原数据集的相似程度。有一点一定记住:网络前几层学到的是通用特征,后面几层学到的是与类别相关的特征。这里有使用的四个场景:

1、新数据集比较小且和原数据集相似。因为新数据集比较小,如果fine-tune可能会过拟合;又因为新旧数据集类似,我们期望他们高层特征类似,可以使用预训练网络当做特征提取器,用提取的特征训练线性分类器。

2、新数据集大且和原数据集相似。因为新数据集足够大,可以fine-tune整个网络。

3、新数据集小且和原数据集不相似。新数据集小,最好不要fine-tune,和原数据集不类似,最好也不使用高层特征。这时可是使用前面层的特征来训练SVM分类器。

4、新数据集大且和原数据集不相似。因为新数据集足够大,可以重新训练。但是实践中fine-tune预训练模型还是有益的。新数据集足够大,可以fine-tine整个网络。

我们这次的作业属于数据集与原来相似而且数据集很小的情况,可以使用预训练网络当做特征提取器,用提取的特征训练线性分类器。

二、代码重点

数据处理

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

vgg_format = transforms.Compose([
                transforms.CenterCrop(224), 
    			#.中心裁剪:transforms.CenterCrop
				#class torchvision.transforms.CenterCrop(size)
				#功能:依据给定的size从中心裁剪
				#参数:
				#size- (sequence or int),若为sequence,则为(h,w),若为int,则(size,size)
                transforms.ToTensor(),
                normalize,
            ])

data_dir = './dogscats'

dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
         for x in ['train', 'valid']}

dset_sizes = {x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets['train'].classes

修改全连接层,冻结卷积层的参数

for param in model_vgg_new.parameters():
    param.requires_grad = False #训练时不更改参数
model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2) #全连接输出两类猫或者狗
model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1) # 数据处理

创建损失函数和优化器,训练模型

criterion = nn.NLLLoss() #设置损失函数

lr = 0.001 # 学习率

optimizer_vgg = torch.optim.SGD(model_vgg_new.classifier[6].parameters(),lr = lr) # 随机梯度下降

#训练模型 (模板 建议直接背诵)
def train_model(model,dataloader,size,epochs=1,optimizer=None):
    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        running_corrects = 0
        count = 0
        for inputs,classes in dataloader:
            inputs = inputs.to(device)
            classes = classes.to(device)
            outputs = model(inputs)
            loss = criterion(outputs,classes)           
            optimizer = optimizer
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _,preds = torch.max(outputs.data,1)
            # statistics
            running_loss += loss.data.item()
            running_corrects += torch.sum(preds == classes.data)
            count += len(inputs)
            print('Training: No. ', count, ' process ... total: ', size)
        epoch_loss = running_loss / size
        epoch_acc = running_corrects.data.item() / size
        print('Loss: {:.4f} Acc: {:.4f}'.format(
                     epoch_loss, epoch_acc))

三、代码优化

1.数据处理

vgg_format = transforms.Compose([
                #transforms.CenterCrop(224), 
    			transforms.Resize((224,224)),
    #这里选择缩放而不是中心裁剪,因为简单地选择重心裁剪会让图像的一些特征直接丢失,严重的情况下直接无法捕捉到物体(cat or dog),这样的情况下卷积也没有什么作用了
                transforms.ToTensor(),
                normalize,
            ])

下面的图片就可以看出使用缩放而不是中心裁剪的原因:原本图片的

上一篇:hdu1232 并查集总结


下一篇:Bert微调技巧实验大全-How to Fine-Tune BERT for Text Classification