一、 导入
1 import torch 2 from torch import nn 3 from d2l import torch as d2l 4 5 batch_size = 256 6 train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
二、初始化参数
# PyTorch不会隐式地调整输入的形状。因此, # 我们在线性层前定义了展平层(flatten),来调整网络输入的形状 # nn.Flatten() 将任何维度的tensor改成一个2d的tensor,第0维度保留,剩下的维度全部展成一个向量 net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10)) def init_weights(m): if type(m) == nn.Linear: nn.init.normal_(m.weight, std=0.01) net.apply(init_weights);
三、Softmax的实现
1 loss = nn.CrossEntropyLoss()
四、优化算法
trainer = torch.optim.SGD(net.parameters(), lr=0.1)
五、训练
num_epochs = 10 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)