图像识别实战(一)----数据集的预处理

图像识别实战(一)----数据集的预处理

1.模块的导入

import os
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

数据集的读取

data_dir = './flower_data'
train_dir = data_dir+ '/train'
valid_dir = data_dir+ '/valid'

2.数据集的预处理

data_transforms = {
    'train': transforms.Compose([transforms.RandomRotation(45),#随机旋转,-45度到45度之间
                                transforms.CenterCrop(224),#从中心开始裁剪
                                transforms.RandomHorizontalFlip(p=0.5),#随机水平翻转 选择一个概率
                                transforms.RandomVerticalFlip(p=0.5),#随机垂直翻转
                                transforms.ColorJitter(brightness=0.2, contrast=0.1,saturation=0.1, hue=0.1),#参数1为亮度,参数2为对比度,参数3为饱和度,参数4为色相
                                transforms.RandomGrayscale(p=0.025),#概率转化为灰度值,3通道就是R=G=B
                                transforms.ToTensor(),#转化为Tensor格式,在预处理结束后必须添加
                                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),#均值,标准差,经过这样处理后的数据符合标准正态分布,即均值为0,标准差为1。使模型更容易收敛。
    'valid': transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(),
                                transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])]), 

transforms.Compose()这个类的主要功能是串联图片的变换操作,类似于一个列表。

3.数据集的组织与加载

batch_size = 8
image_datasets = {x:datasets.ImageFolder(os.path.join(data_dir, x),data_transforms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True) for x in ['train', 'valid']}#就是用来包装所使用的数据,每次抛出一批数据
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'valid']}
class_names = image_datasets['train'].classes
dataset=torchvision.datasets.ImageFolder(
                       root, #图片存储的根目录
                       transform=None, #图片的预处理操作
                       target_transform=None, #对图片类别做预处理操作
                       loader=<function default_loader>, #数据集加载方式
                       is_valid_file=None)#获取图像文件的路径并检查该文件是否为有效文件
#print(dataset.classes)  #根据分的文件夹的名字来确定的类别
#print(dataset.class_to_idx) #按顺序为这些类别定义索引为0,1...
#print(dataset.imgs) #返回从所有文件夹中得到的图片的路径以及其类别

我们打印出 image_datasets

{'train': Dataset ImageFolder
     Number of datapoints: 3614
     Root location: F:/flower_data/train
     StandardTransform
 Transform: Compose(
                RandomRotation(degrees=(-45, 45), resample=False, expand=False)
                CenterCrop(size=(224, 224))
                RandomHorizontalFlip(p=0.5)
                RandomVerticalFlip(p=0.5)
                ColorJitter(brightness=[0.8, 1.2], contrast=[0.9, 1.1], saturation=[0.9, 1.1], hue=[-0.1, 0.1])
                RandomGrayscale(p=0.025)
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ),
 'valid': Dataset ImageFolder
     Number of datapoints: 56
     Root location: F:/flower_data/valid
     StandardTransform
 Transform: Compose(
                Resize(size=256, interpolation=PIL.Image.BILINEAR)
                CenterCrop(size=(224, 224))
                ToTensor()
                Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )}

我们打印出 dataset_sizes 帮助理解{}中的逻辑

{'train': 3614, 'valid': 56}

4.数据集图像展示

def im_convert(tensor):
    """展示数据"""
    image = tensor.to('cpu').clone().detach()#将Tensor数据从GPU放到CPU,复制和这个Tensor并且去掉梯度
    image = image.numpy().squeeze()#祛除数组中为1 的维度
    image = image.transpose(1,2,0)#Pytorch中为[Channels, H, W],而plt.imshow()中则是[H, W, Channels],所以交换一下通道
    image = image*np.array((0.229, 0.224, 0.225))+np.array((0.485, 0.456, 0.406))# 反转一下transforms.Normalize()的过程
    image = image.clip(0, 1)#归一化
    return image
fig = plt.figure(figsize=(20, 12))#设置图像尺寸
columns = 4
rows = 2
#我们设置的一个batchsize=8,所以dataloaders里只有8张图片,最多显示8张图片
dataiter = iter(dataloaders['valid'])#iter()迭代器
inputs, classes = dataiter.next()
for idx in range (columns*rows):
    ax = fig.add_subplot(rows,columns, idx+1,xticks=[], yticks=[])#图像区域划分row行,colums列,第idx+1个
    ax.set_title(class_names[classes[idx].item()])
    plt.imshow(im_convert(inputs[idx]))

plt.show()   

图像识别实战(一)----数据集的预处理

上一篇:我的猫狗大战分类


下一篇:Github上标星82.1K+star面试笔记,可以帮你搞定95%以上的Java面试,已经帮助多人拿下offer