5、实战:CIFAR-10分类

与其他教程不一样的地方是加载的本地已下载数据(代码中下载速度太慢)。关于数据集的说明点击此链接

1、下载数据集,复制此链接到迅雷下载 http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz

2、解压到E:/data目录中,

5、实战:CIFAR-10分类

3、jupyter中撸代码

【说明】

①   Dataloader是一个可迭代的对象,它将datasets返回的每一条数据样本拼接成一个batch,并可以多线程加速优化、数据打乱等操作。

  当datasets的所有数据遍历一遍之后,对Dataloader也完成了一次迭代。

②    Dataloader类型与datasets一样返回数据、标签索引,不过Dataloader是以batch为单位返回的,即batch_size个元素组成的向量。

【步骤】

1、使用torchvision加载并预处理数据集
2、定义网络
3、定义损失函数和优化器
4、训练网络并更新网络参数
5、测试网络

######################################## 1、使用torchvision加载并预处理数据集
import torch as t
import torchvision as tv
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage

#定义对数据的预处理
transform=transforms.Compose([
    transforms.ToTensor(), #转为Tensor
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))]) #归一化,到[-1,1]
#训练集
#torchvision datasets 的输出是 PILImage 类型的 images, 像素取值范围在 [0, 1]。 
#我们将其变换为归一化范围在[-1 , 1] 的 Tensors 类型
trainset=tv.datasets.CIFAR10(
    root="E:/data", #注意我的路径
    train=True,
    download=False, #注意不下载
    transform=transform) #载入数据的时候,转换为Tensor类型,并归一化到[-1,1]
trainloader=t.utils.data.DataLoader(
    trainset,
    batch_size=4, #batch大小
    shuffle=True, #打乱
    num_workers=2) #2个线程
#测试集
testset=tv.datasets.CIFAR10(
    root="E:/data",
    train=False,
    download=False,
    transform=transform)
testloader=t.utils.data.DataLoader(
    testset,
    batch_size=4,
    shuffle=False,
    num_workers=2)
#类别标签
classes=("plane","car","bird","cat","deer","dog","frog","horse","ship","truck")

显示datasets中的第一幅图

(data,label)=trainset[0] #获取数据集的第一个元素,返回数据和标签索引
print(classes[label]) #label是6
show=ToPILImage() #把Tensor转成Image,方便可视化
show((data+1)/2).resize((100,100)) #(data+1)/2反归一化[-1,1]为[0,1],resize把32×32图变为100×100

5、实战:CIFAR-10分类

显示Dataloader中的第一批图(4幅图为1个batch)

# 随机获取4张训练集图片(trainloader中4个元素为一批(batch))
dataiter = iter(trainloader)
(images, labels) = dataiter.next() #以batch为单位,Dataloader类型与datasets一样返回数据、标签索引
print(labels)
# 输出对应的标签
print(' '.join('%s' % classes[labels[j]] for j in range(4)))
# 显示图像
show(tv.utils.make_grid((images+1)/2)).resize((400,100)) #显示4个图
#show((images[0]+1)/2).resize((100,100)) #显示第一个图

5、实战:CIFAR-10分类

 

上一篇:人工智能能力提升指导总结


下一篇:人工智能能力提升指导总结