pytorch复习与总结

今天来复习pytorch的数据读取机制
torch.utils.data.DataLoader();构建可迭代的数据装载器,每一个for 循环,每一个iteration,都是从DataLoader中获取一个Batch_size大小的数据。
有没有好奇过,就加载这几个类,然后就可以把数据读取,而且还能以批量的形式加载,这是怎样的一个过程呢?今天我们就来慢慢的深入学习,学到哪是哪。
pytorch复习与总结

其中DataLoader大概有几个重要的参数,分别为:
1、dataset:属于DataSet类,决定数据从哪读取,怎么读取。
2、num_works:是否多进程读取
3、shuffle:每个epoch是否是乱序
4、batchsize:批量大小
5、drop_last:组成批量是,多余的是不是要剔除掉,
先来理解epoch
所有训练样本都已经输入到模型中,称为一个epoch,iteration:一批样本输入到模型中
再来理解Batchsize:
批量的大小,决定了一个epoch有多少个iteration
Dataset复习
torch.utils.data.Dataset(): Dataset抽象类,所有自定义的Dataset都必须继承他,并且还要复习__getitem__()和__len__()这两个函数,那么第一个函数是干啥用的呢?第二个是干啥用的呢,我这里通过学习查资料,理解了这么一个过程:
getitem:这个函数主要是来收集并返回图片和标签信息的,这个函数有两个参数,

def __getitem__(self, item):

其中item是干啥的呢?这个就是一个索引,很重的一个参数,我们在这个函数里读取信息的时候就是根据这个item 参数来寻找每张图片的信息的,其中过程可以在他的父类中看到
pytorch复习与总结
那么这个重写的函数是要收集哪些信息呢?
他的作用是收集我们训练集或者测试集的图片和图片所对应的标签号码,而且把图片信息转换为张量信息也是在这个函数里面发生的,转换完之后会返回出去
pytorch复习与总结
他把信息返回到哪里了呢?
那就是返回到你创建的这个自定义的数据类里面了
pytorch复习与总结
只有这样,你才能实例化对象把这写数据打包成批量或者单个。
**len_()**返回的就是数据集的长度,这个是很简单的return len(self.image)
我们在自定义数据集的时候还有一点很重要
我们怎么收集图片呢?
pytorch复习与总结
来看这张图片,这长图片表达的意思是:我要获取图片的具体位置(自己在__init___已经设置好了)和图片对应的标签,把获取后的信息随机打乱一下,放入到两个列表里面,并返回出去,返回给谁呢?
pytorch复习与总结
返回给__init__里面自定义的两个变量,这两个变量负责将数据包里面的内容根据你设定的训练测试返回来剪辑数据的大小。是不是再想剪辑后干啥呢?放到哪里呢?
还记得上边我们提到的__getitem__(self, item)?
pytorch复习与总结
返回给了他(红线圈起来的),然后继续往下执行,张量化。这就是一整个过程
下边张贴一下所有的代码过程。

import csv
import glob
import os
import random

import torch
import torchvision
import visdom
from PIL import Image
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader

class myData(Dataset):
    def __init__(self,file,size,mode):
        self.file=file
        self.size=size
        self.label_name={}#存放文件名称和标号
        for name in (os.listdir(os.path.join(file))):
            if os.path.isdir(os.path.join(self.file,name)):
                self.label_name[name]=len(self.label_name.keys())
        #print(self.label_name)
        self.imgaes,self.labels=self.get_img_info('image_csv')
        #--------------划分训练集和测试集范围--------------
        if mode=='train':
            self.images=self.imgaes[:int(0.8*len(self.imgaes))]
            self.labels=self.labels[:int(0.8*len(self.labels))]
        else:
            self.images=self.imgaes[int(0.8*len(self.imgaes)):]
            self.labels=self.labels[int(0.8*len(self.labels)):]
        #---------------划分训练集和测试集范围-------------
        pass
    def __len__(self):
        return len(self.images)
    pass
    def get_img_info(self,filename):#这个时获得图片信息的函数
        images=[]
        labels=[]
        for name in self.label_name.keys():#取出字典的键
            images+=glob.glob(os.path.join(self.file,name,'*.jpg'))
            images+=glob.glob(os.path.join(self.file,name,'*png'))
            #print(images)
            pass
        random.shuffle(images)#把这里面所有的地址和都打乱
        with open(os.path.join(self.file,filename),mode='w',newline='') as f:
            writer=csv.writer(f)
            for file in images:
                img=file.split(os.sep)[-2]
                label=self.label_name[img]
                labels.append(label)
                writer.writerow([file,label])
                pass
            pass
        #print(images)
        return images,labels
    def __getitem__(self, item):
        image=self.images[item]
        label=self.labels[item]
        tf=torchvision.transforms.Compose([
            lambda x:Image.open(x).convert('RGB'),
            transforms.Resize((int(self.size),int(self.size))),
            transforms.RandomRotation(15),
            transforms.CenterCrop(64),
            transforms.ToTensor()#这个放在最后操作,前边那几个是在图片的基础上修改的,这个把修改好的再转化为张量
        ])
        img=tf(image)
        label=torch.tensor(label)
        return img,label


def main():
    #viz=visdom.Visdom()
    mydata_train = myData('traindata', 64, 'train')
    mydata_test=myData('traindata', 64, 'test')
    #x,y=next(iter(mydata))
    #viz.image(x,win='sample_x',opts=dict(title='sample_x'))
    train=DataLoader(mydata_train,batch_size=32,shuffle=True)#把数据打包成批量Batchsize
    test=DataLoader(mydata_test,batch_size=32)
    '''
        print(train)
    for x,y in train:
        viz.images(x,nrow=8,win='batch',opts=dict(title='bacht'))
    '''

if __name__=='__main__':
    main()

这里还有几个知识点要记录:
数据增强:
对数据集进行变换,让模型更具有泛化能力,比如

       transforms.RandomRotation(15),
       transforms.CenterCrop(64),

上边这两个操作,具体的可以去网上查找

transforms.ToTensor()

把图像转为张量,同时进行归一化操作,将张量从0-255转到0-1之间

transforms.Normalize()

加快模型的收敛速度
总结:
今天复习了自定义数据的收集过程,到底是一个怎么收集的过程,然后就是一步一步的介绍了整体收集的过程。

上一篇:python图像读取速度


下一篇:springboot-注解汇总,狂刷200道数据结构与算法