上回讲到Lenet5的分类问题。
用PyTorch实现卷积神经网络LeNet5的CIFAR10分类问题。
今天讲一下一个非常简单的ResNet,总共十层,包括一个卷积层、四个block(每个block两个卷积层)和一个全连接层。block的基本单元如图所示。
每个block有一个短接x,最后和F(x)相加。
以下是block(残差块)的定义:
import torch
from torch.nn import functional as F
from torch import nn
class ResBlk(nn.Module):
    """Basic residual block: two 3x3 conv-BN layers plus a shortcut.

    The shortcut is the identity when input and output shapes match;
    otherwise a 1x1 conv (with BatchNorm) projects x to the right shape
    so the element-wise add is valid.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        """
        :param ch_in: number of input channels
        :param ch_out: number of output channels
        :param stride: stride of the first conv; stride > 1 halves the
            spatial size, which keeps memory usage down on deeper nets
        """
        super(ResBlk, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)
        # Projection shortcut is only needed when the shapes change;
        # use the identity otherwise (standard ResNet design).
        if ch_in != ch_out or stride != 1:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )
        else:
            self.extra = nn.Identity()

    def forward(self, x):
        """[b, ch_in, h, w] -> [b, ch_out, h/stride, w/stride]"""
        out = F.relu(self.bn1(self.conv1(x)))
        # Whether to apply ReLU here before the add is a design choice;
        # kept as in the original.
        out = F.relu(self.bn2(self.conv2(out)))
        # Residual connection: F(x) + shortcut(x). Element-wise add
        # requires matching channel count, handled by self.extra.
        return self.extra(x) + out
以下是将一个卷积层、四个block和一个全连接层串联起来:
class ResNet10(nn.Module):
    """10-layer ResNet for 32x32 images: one stem conv, four residual
    blocks (two convs each), and one fully-connected classifier head.
    """

    def __init__(self, num_classes=10):
        """
        :param num_classes: size of the classifier output
            (default 10, matching CIFAR-10)
        """
        super(ResNet10, self).__init__()
        # Stem: [b, 3, h, w] -> [b, 64, h, w]
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
        )
        # Four blocks follow; each stride-2 block halves h and w.
        # [b, 64, h, w] -> [b, 128, h/2, w/2]
        self.blk1 = ResBlk(64, 128, stride=2)
        # [b, 128, h/2, w/2] -> [b, 256, h/4, w/4]
        self.blk2 = ResBlk(128, 256, stride=2)
        # [b, 256, h/4, w/4] -> [b, 512, h/8, w/8]
        self.blk3 = ResBlk(256, 512, stride=2)
        # [b, 512, h/8, w/8] -> [b, 512, h/16, w/16]
        self.blk4 = ResBlk(512, 512, stride=2)
        # Global average pooling in forward() reduces the spatial dims
        # to 1x1, so the linear layer always sees a 512-dim vector
        # regardless of the input image size.
        self.outlayer = nn.Linear(512 * 1 * 1, num_classes)

    def forward(self, x):
        """[b, 3, h, w] -> [b, num_classes] logits."""
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # [b, 512, h', w'] -> [b, 512, 1, 1]
        x = F.adaptive_avg_pool2d(x, [1, 1])
        # Flatten to [b, 512] for the classifier.
        x = x.view(x.size(0), -1)
        return self.outlayer(x)
以下是主程序代码
import torch
import torchvision
from torch import nn, optim
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
#from lenet5 import Lenet5
from resnet import ResNet10
# CIFAR-10 label names, indexed by class id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)
def main():
    """Train ResNet10 on CIFAR-10 for 25 epochs and report overall and
    per-class test accuracy after each epoch.
    """
    transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
    ])
    training = datasets.CIFAR10('data', True, transform=transform, download=True)
    trainloader = DataLoader(training, batch_size=32, shuffle=True)
    test = datasets.CIFAR10('data', False, transform=transform, download=True)
    testloader = DataLoader(test, batch_size=16, shuffle=True)

    # Fall back to CPU when CUDA is unavailable instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNet10().to(device)
    criteon = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    print(model)

    for epoch in range(25):
        model.train()
        for batchix, (x, label) in enumerate(trainloader):
            x, label = x.to(device), label.to(device)
            logits = model(x)
            loss = criteon(logits, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Loss of the last batch of this epoch.
        print(epoch, loss.item())

        model.eval()
        class_correct = [0.0] * 10
        class_total = [0.0] * 10
        with torch.no_grad():
            total_correct = 0
            total_num = 0
            for x, label in testloader:
                x, label = x.to(device), label.to(device)
                logits = model(x)
                pred = logits.argmax(dim=1)
                total_correct += torch.eq(pred, label).float().sum().item()
                total_num += x.size(0)
                # 1-D boolean mask of per-sample correctness; no squeeze
                # needed (and squeeze would break indexing on a size-1
                # batch by producing a 0-dim tensor).
                c = (pred == label)
                # Iterate over the actual batch size: the last batch may
                # be smaller than batch_size.
                for i in range(label.size(0)):
                    labels = label[i]
                    class_correct[labels] += c[i].item()
                    class_total[labels] += 1
            acc = total_correct / total_num
            print(epoch, acc)
            for i in range(10):
                print('Accuracy of %5s : %2d %%' % (
                    classes[i], 100 * class_correct[i] / class_total[i]))


if __name__ == '__main__':
    main()
跑了25个epoch能够达到81.28%的识别率,比LeNet5提高了13%;如果持续跑下去可达82%以上。