1. 加载MNIST数据集
python代码:
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt
import os
# Permit a second copy of the OpenMP runtime to be loaded in this process.
# NOTE(review): presumably a workaround for the "OMP: Error #15 ... already
# initialized" abort that occurs when several libraries ship their own OpenMP
# runtime; Intel describes this override as unsafe (it can crash or give
# incorrect results), so confirm it is still needed in your environment.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def _mnist_split(use_train_split):
    """Construct one split of the MNIST dataset stored under ``data``.

    ``use_train_split`` picks the training split when True and the test
    split when False. ``download=True`` lets torchvision fetch the files into
    ``root`` if they are not already there, and ``ToTensor`` turns each image
    into a float tensor rescaled from the 0-255 pixel range to [0, 1].
    """
    return datasets.MNIST(
        root='data',
        train=use_train_split,
        download=True,
        transform=ToTensor(),
    )


train_data = _mnist_split(True)   # training split
test_data = _mnist_split(False)   # test split
下载成功后的输出结果:
添加下列代码，查看数据集的大小并显示其中一条数据:
# Report the shapes of the tensors that back the training split.
print(train_data.data.size())     # stacked image tensor
print(train_data.targets.size())  # label tensor

# Render one training image in grayscale, titled with its integer label.
sample_idx = 130
sample_image = train_data.data[sample_idx]
sample_label = train_data.targets[sample_idx]
plt.imshow(sample_image.numpy(), cmap='gray')
plt.title('%i' % sample_label)
plt.show()
数据如下图所示:
2. 快速构建CNN网络
python代码:
import torch
from torch import nn


class CNN(nn.Module):
    """Small two-stage convolutional classifier for MNIST-style digits.

    Sized for single-channel 28x28 input, i.e. ``(batch, 1, 28, 28)``. Each
    convolution keeps the spatial size (5x5 kernel, stride 1, padding 2) and
    the following 2x2 max-pool halves it, so 28 -> 14 -> 7. That is why the
    final linear layer takes ``32 * 7 * 7`` features.

    ``forward`` returns raw, unnormalised scores (logits) of shape
    ``(batch, 10)``; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 16 channels; 28x28 -> 28x28 (conv) -> 14x14 (pool).
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,        # number of input channels
                out_channels=16,      # number of output channels (filters)
                kernel_size=(5, 5),   # size of each filter
                stride=(1, 1),        # step of the filter
                padding=2,            # zero-padding that preserves H and W
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: 16 -> 32 channels; 14x14 -> 14x14 (conv) -> 7x7 (pool).
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, (5, 5), (1, 1), 2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        self.flat = nn.Flatten()  # (batch, 32, 7, 7) -> (batch, 32 * 7 * 7)
        self.out = nn.Linear(32 * 7 * 7, 10)  # one score per class

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a ``(batch, 1, 28, 28)`` batch to ``(batch, 10)`` logits."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.flat(x)
        return self.out(x)