本篇代码有不清楚的地方,可以参考:
cifar-10+resnet.
这篇除了搭建的CNN不一样,其他地方完全一样。
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
cifar_train = datasets.CIFAR10('cifar',True,transform=transforms.Compose(( #true表示加载的是训练集
transforms.Resize(32,32),
transforms.ToTensor())))
cifar_train_batch = DataLoader(cifar_train,batch_size = 30,shuffle = True)
cifar_test = datasets.CIFAR10('cifar',False,transform=transforms.Compose(( #false表示加载的是测试集
transforms.Resize(32,32),
transforms.ToTensor())))
cifar_test_batch = DataLoader(cifar_test_one,batch_size = 30,shuffle = True)
搭建CNN:
from torch import nn
class lenet5(nn.Module):
def __init__(self):
super(lenet5,self).__init__()
#两层卷积
self.conv_unit = nn.Sequential(
nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0),
nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5,stride=1,padding=0),
nn.AvgPool2d(kernel_size=2,stride=2,padding=0)
)
'''
卷积之后先flatten,然后再全连接,但是nn.Module中没有flatten的操作,
所以flatten不能包含在sequential中
因此先用一个sequential完成卷积操作,然后flatten,然后再用一个sequential完成全连接
'''
#全连接层
self.fc_unit = nn.Sequential(
nn.Linear(16*5*5,120),
nn.ReLU(),
nn.Linear(120,84),
nn.ReLU(),
nn.Linear(84,10)
)
def forward(self,x):
batch_size = x.shape[0]
x = self.conv_unit(x)
x = x.reshape(batch_size,16*5*5)
logits = self.fc_unit(x)
return logits
device = torch.device('cuda')
net = lenet5()
net = net.to(device) #将网络部署到GPU上
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(),lr=1e-3)
#开始训练
for epoch in range(5):
for batchidx,(x,label) in enumerate(cifar_train_batch):
x,label = x.to(device),label.to(device) #x.size (bcs,3,32,32) label.size (bcs)
logits = net.forward(x)
loss = loss_fn(logits,label) #logits.size:bcs*10,label.size:bcs
#开始反向传播:
optimizer.zero_grad()
loss.backward() #计算gradient
optimizer.step() #更新参数
if (batchidx+1)%400 == 0:
print('这是本次迭代的第{}个batch'.format(batchidx+1)) #本例中一共有50000张照片,每个batch有30张照片,所以一个epoch有1667个batch
print('这是第{}迭代,loss是{}'.format(epoch+1,loss.item()))
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第1迭代,loss是1.1926124095916748
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第2迭代,loss是1.1064329147338867
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第3迭代,loss是0.8839625120162964
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第4迭代,loss是1.0676394701004028
这是本次迭代的第400个batch
这是本次迭代的第800个batch
这是本次迭代的第1200个batch
这是本次迭代的第1600个batch
这是第5迭代,loss是1.0307090282440186
#测试
net.eval()
with torch.no_grad():
correct_num = 0 #预测正确的个数
total_num = 0 #测试集中总的照片张数
batch_num = 0 #第几个batch
for x,label in cifar_test: #x的size是30*3*32*32(30是batch_size,3是通道数),label的size是30.
#cifar_test中一共有10000张照片,所以一共有334个batch,因此要循环334次
x,label = x.to(device),label.to(device)
logits = net.forward(x)
pred = logits.argmax(dim=1)
correct_num += torch.eq(pred,label).float().sum().item()
total_num += x.size(0)
batch_num += 1
if batch_num%50 == 0:
print('这是第{}个batch'.format(batch_num)) #一共有10000/30≈334个batch
acc = correct_num/total_num #最终的total_num是10000
print('测试集上的准确率为:',acc)
这是第50个batch
这是第100个batch
这是第150个batch
这是第200个batch
这是第250个batch
这是第300个batch
测试集上的准确率为: 0.5496