3 Softmax Regression (concise implementation)

import torch
import torchvision

def get_data(batch_size=50):
    trans = torchvision.transforms.ToTensor()
    mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True,
                                                    transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False,
                                                   transform=trans, download=True)
    train = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True)
    test = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False)
    return train, test

train_iter, test_iter = get_data()
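
As a quick sanity check (not in the original post), one batch from train_iter should contain 50 grayscale images of shape 1×28×28 and 50 labels:

X, y = next(iter(train_iter))
print(X.shape, y.shape)   # torch.Size([50, 1, 28, 28]) torch.Size([50])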

from d2l import torch as d2l

lr = 0.03
epoch = 50

1 net

def para_init(m):
    if type(m) == torch.nn.Linear:
        torch.nn.init.normal_(m.weight, mean=0.0, std=0.1)

net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(28 * 28, 10))
net.apply(para_init)
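
A quick shape check with a made-up random batch: Flatten turns each 1×28×28 image into a 784-dimensional vector, and the Linear layer maps it to 10 class scores.

dummy = torch.rand(2, 1, 28, 28)   # two fake images, for illustration only
print(net(dummy).shape)            # torch.Size([2, 10]): one raw score per class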

2 loss

loss = torch.nn.CrossEntropyLoss()
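
torch.nn.CrossEntropyLoss applies log-softmax internally, which is why the network above has no explicit softmax layer. A small check, with toy logits made up for illustration, comparing it to a manual log-softmax plus negative log-likelihood:

logits = torch.tensor([[2.0, 0.5, -1.0]])            # one sample, three class scores
target = torch.tensor([0])                           # the true class is index 0
manual = -torch.log_softmax(logits, dim=1)[0, 0]     # negative log-probability of the true class
print(loss(logits, target).item(), manual.item())    # both print the same value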

3 optimizer

op = torch.optim.SGD(net.parameters(), lr=lr)

def accuracy(y_pred, y):
    if len(y_pred.shape) > 1 and y_pred.shape[1] > 1:
        y_pred = y_pred.argmax(axis=1)
    cmp = (y_pred.type(y.dtype) == y)
    return float(cmp.type(y.dtype).sum())
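
A toy usage example (the tensors are made up for illustration). Note that accuracy returns the number of correct predictions in the batch, not a ratio; the callers below divide by the sample count.

y_hat = torch.tensor([[0.1, 0.8, 0.1],
                      [0.6, 0.3, 0.1]])   # predictions for two samples
y = torch.tensor([1, 2])
print(accuracy(y_hat, y))                 # 1.0: only the first prediction is correct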

def evaluate_accuracy(net, test_iter):
    if isinstance(net, torch.nn.Module):
        net.eval()
    all_accuracy, all_data = 0, 0
    with torch.no_grad():
        for X, y in test_iter:
            all_accuracy += accuracy(net(X), y)
            all_data += y.numel()
    return all_accuracy / all_data

4 train

def train_epoch(net, train_iter, loss, op):
    # Return the average loss and the average accuracy over this epoch.
    if isinstance(net, torch.nn.Module):
        net.train()

    metric = d2l.Accumulator(3)  # number of samples, summed loss, correct predictions
    for X, y in train_iter:
        y_pred = net(X)
        l = loss(y_pred, y)
        op.zero_grad()
        l.backward()
        op.step()
        metric.add(y.numel(), float(l) * y.numel(), accuracy(y_pred, y))
    return metric[1] / metric[0], metric[2] / metric[0]

def train(net, train_iter, test_iter, epoch, loss, op):
    animator = d2l.Animator(xlabel='epoch', xlim=[1, epoch], ylim=[0.3, 0.9],
                            legend=['train loss', 'train acc', 'test acc'])
    for i in range(epoch):
        train_metric = train_epoch(net, train_iter, loss, op)
        acc = evaluate_accuracy(net, test_iter)
        animator.add(i + 1, train_metric + (acc,))

train(net, train_iter, test_iter, epoch, loss, op)

Prediction

def predict_ch3(net, test_iter, n=6):  #@save
    """Predict labels (defined in Chapter 3)."""
    for X, y in test_iter:
        break
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    print(titles[0:n])
    d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])

predict_ch3(net, test_iter)

3.7.6 Exercises

1. Try adjusting the hyperparameters, such as the batch size, the number of epochs, and the learning rate, and see how the results change.

2. Increase the number of epochs. Why might the test accuracy decrease after a while? How can we fix this?
This is overfitting; one fix is early stopping, sketched below.
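
A minimal early-stopping sketch, reusing train_epoch and evaluate_accuracy from above. The function name train_early_stop and the patience parameter are made up for illustration, not part of the original code:

def train_early_stop(net, train_iter, test_iter, epoch, loss, op, patience=5):
    # Sketch: stop when test accuracy has not improved for `patience` epochs.
    best_acc, wait = 0.0, 0
    for i in range(epoch):
        train_epoch(net, train_iter, loss, op)
        acc = evaluate_accuracy(net, test_iter)
        if acc > best_acc:
            best_acc, wait = acc, 0          # new best: reset the counter
        else:
            wait += 1                        # no improvement this epoch
            if wait >= patience:
                print(f'early stop at epoch {i + 1}, best test acc {best_acc:.3f}')
                break
    return best_acc

It is called the same way as train, e.g. train_early_stop(net, train_iter, test_iter, epoch, loss, op).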
