网络
1. ModuleList 和 Sequential
# 1.使用torch.nn.ModuleList定义网络。
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
layers = []
layers.append(torch.nn.Conv2d(1, 1, kernel_size=2, stride=1, bias=True, padding=0))
layers.append(torch.nn.ReLU(inplace=True))
layers.append(torch.nn.BatchNorm2d(num_features=1))
self.net = torch.nn.ModuleList(layers)
def forward(self, data):
for layer in self.net:
data = layer(data)
return data
# 2.使用torch.nn.Sequential定义网络。
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.net = torch.nn.Sequential(
torch.nn.Conv2d(1, 1, kernel_size=2, stride=1, bias=True, padding=0),
torch.nn.ReLU(inplace=True),
torch.nn.BatchNorm2d(num_features=1)
)
def forward(self, data):
return self.net(data)
使用 torch.nn.ModuleList 定义网络,要在forward函数中定义数据经过每一层的顺序,如第 12、13 行。
使用 torch.nn.Sequential 定义网络,直接输入数据无须定义数据怎样经过每一层,如第27行。
2. 创建 LeNet
创建一个包含卷积层、池化层、激活函数层、BN层、Dropout层、全连接层的 LeNet。