import torch
import torchvision
def get_data(batch_size=50):
    """Build DataLoaders for the Fashion-MNIST training and test splits.

    Both splits are read from ``../data``. ``download=True`` lets torchvision
    fetch the files there when needed. Every image goes through the same
    ``ToTensor`` transform.

    Args:
        batch_size: Mini-batch size used by both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    to_tensor = torchvision.transforms.ToTensor()
    loaders = []
    for is_train in (True, False):
        dataset = torchvision.datasets.FashionMNIST(
            root="../data", train=is_train, transform=to_tensor, download=True
        )
        # Shuffle only the training split; keep the test order fixed.
        loaders.append(
            torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
        )
    return tuple(loaders)
# Fashion-MNIST loaders at get_data's default batch size of 50.
train_iter, test_iter = get_data(batch_size=50)
from d2l import torch as d2l
# Training hyperparameters: SGD learning rate and number of passes over the data.
lr, epoch = 0.03, 50
# 1. Network (model definition)
def para_init(m):
    """Initialize the weights of linear layers from a normal distribution.

    Meant to be passed to ``torch.nn.Module.apply``, which calls it once per
    submodule. Weights of ``torch.nn.Linear`` modules, including subclasses,
    are redrawn from N(0, 0.1^2). Biases and every other kind of module
    (e.g. ``Flatten``) are left untouched.

    Args:
        m: A submodule visited by ``apply``.
    """
    # isinstance instead of an exact type check, so Linear subclasses are covered too.
    if isinstance(m, torch.nn.Linear):
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.1)
# Softmax-regression model: flatten each image into 28 * 28 = 784 features, then
# map them through one linear layer to 10 class logits. para_init then redraws
# the linear weights.
net = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=28 * 28, out_features=10),
)
net.apply(para_init)
# 2. Loss function
# Cross-entropy on raw logits; "mean" (the default) averages over the mini-batch.
loss = torch.nn.CrossEntropyLoss(reduction="mean")
# 3. Optimizer
# Plain mini-batch SGD over every model parameter, using the learning rate lr.
op = torch.optim.SGD(params=net.parameters(), lr=lr)
def accuracy(y_pred, y):
    """Count the correct predictions in one batch.

    Args:
        y_pred: Either per-class scores with the class dimension at dim 1,
            which are reduced with ``argmax`` over that dimension, or a tensor
            of predicted class indices.
        y: Tensor of true class indices.

    Returns:
        The number of correct predictions as a Python float. This is a count,
        not a ratio; callers divide by the number of samples.
    """
    # Reduce only when a class dimension with more than one entry exists.
    if y_pred.dim() > 1 and y_pred.shape[1] > 1:
        y_pred = y_pred.argmax(dim=1)
    # Cast predictions to the label dtype so == compares like with like.
    correct = y_pred.type(y.dtype) == y
    return float(correct.type(y.dtype).sum())
def evaluate_accuracy(net, test_iter):
    """Compute the classification accuracy of ``net`` over a data iterable.

    Args:
        net: The model. If it is a ``torch.nn.Module``, it is switched to
            evaluation mode and left in that mode on return.
        test_iter: Iterable of ``(X, y)`` mini-batches.

    Returns:
        The fraction of correctly classified samples, as a float.

    Raises:
        ZeroDivisionError: If ``test_iter`` yields no samples.
    """
    if isinstance(net, torch.nn.Module):
        net.eval()
    correct, total = 0.0, 0
    # Evaluation needs no gradients; skip building autograd graphs.
    with torch.no_grad():
        for X, y in test_iter:
            correct += accuracy(net(X), y)
            total += y.numel()
    return correct / total
# 4. Training
def train_epoch(net, train_iter, loss, op):
    """Run one epoch of mini-batch training.

    Args:
        net: The model. It is switched to training mode if it is a
            ``torch.nn.Module``.
        train_iter: Iterable of ``(X, y)`` mini-batches.
        loss: Loss function applied to ``(net(X), y)``. It is assumed to
            return the batch *mean*, since it is scaled by the batch size
            below to recover a per-batch sum.
        op: Optimizer over ``net``'s parameters.

    Returns:
        ``(average loss per sample, average accuracy)`` over the epoch.
    """
    if isinstance(net, torch.nn.Module):
        net.train()
    # Running sums: number of samples, total loss, number of correct predictions.
    metric = d2l.Accumulator(3)
    for X, y in train_iter:
        y_pred = net(X)
        l = loss(y_pred, y)
        op.zero_grad()
        l.backward()
        op.step()
        batch_n = y.numel()
        # Detach so the bookkeeping multiply does not extend the autograd graph.
        metric.add(batch_n, l.detach() * batch_n, accuracy(y_pred, y))
    return metric[1] / metric[0], metric[2] / metric[0]
def train(net, train_iter, test_iter, epoch, loss, op):
    """Train ``net`` for ``epoch`` epochs and plot progress after each one.

    After every epoch, the d2l animator gets three points: training loss,
    training accuracy, and test accuracy.

    Args:
        net: The model to train.
        train_iter: Iterable of training ``(X, y)`` mini-batches.
        test_iter: Iterable of test ``(X, y)`` mini-batches.
        epoch: Number of epochs to run.
        loss: Loss function passed through to ``train_epoch``.
        op: Optimizer passed through to ``train_epoch``.
    """
    animator = d2l.Animator(
        xlabel='epoch',
        xlim=[1, epoch],
        ylim=[0.3, 0.9],
        legend=['train loss', 'train acc', 'test acc'],
    )
    for current_epoch in range(1, epoch + 1):
        train_loss, train_acc = train_epoch(net, train_iter, loss, op)
        test_acc = evaluate_accuracy(net, test_iter)
        animator.add(current_epoch, (train_loss, train_acc, test_acc))
# Run the full training loop with the objects defined above.
train(
    net=net,
    train_iter=train_iter,
    test_iter=test_iter,
    epoch=epoch,
    loss=loss,
    op=op,
)
# 预测 (Prediction)
def predict_ch3(net, test_iter, n=6):  #@save
    """Predict labels for the first test batch and display them (defined in Chapter 3).

    Shows the first ``n`` images of the first batch from ``test_iter``. Each
    image's title is its true label, then the predicted label on a second line.

    Args:
        net: Model mapping a batch of images to per-class scores.
        test_iter: Iterable of ``(X, y)`` mini-batches; only the first is used.
        n: Number of images to display.
    """
    X, y = next(iter(test_iter))
    trues = d2l.get_fashion_mnist_labels(y)
    # Pure inference: no gradient tracking needed.
    with torch.no_grad():
        preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
# Show predictions for the first six images of the first test batch.
predict_ch3(net, test_iter, n=6)
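# 3.7.6 练习 (Exercises)
# 1. 尝试调整超参数，例如批量大小、迭代周期数和学习率，并查看结果。
#    (Try adjusting hyperparameters such as batch size, number of epochs, and learning rate, and observe the results.)
# 2. 增加迭代周期的数量。为什么测试准确率会在一段时间后降低？我们怎么解决这个问题？
#    (Increase the number of epochs. Why might test accuracy drop after a while? How can we fix it?)
#    答：训练过久会使模型过拟合训练集，即训练准确率仍在上升而测试准确率开始下降。
#    可以使用早停法（early stopping）：在验证集准确率不再提升时停止训练；也可以加入权重衰减等正则化。
#    (Answer: the model overfits the training set. Use early stopping — stop once validation accuracy stops improving — or add regularization such as weight decay.)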