ResNet 的特点在于残差结构，可以有效地缓解梯度消失。对 $x+f(x)$ 求导可得 $1+f'(x)$，即梯度中始终包含常数项 1：即使 $f'(x)$ 很小，梯度也能经由恒等映射直接回传到浅层，因而不会消失。从工程实现上得到的启示是：目前网络设计中的基本卷积结构，都是由“卷积层 + BN 层 + 激活函数”构成的小模块，其余超参数则是调优的结果。本文的 ResNet-18 实现基于 PaddlePaddle 2.2 版本。
import paddle
import paddle.nn as nn
#paddle.set_device('cpu')
class Identity(nn.Layer):
    """Pass-through layer, used as the residual shortcut when no projection is needed."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
class Block(nn.Layer):
    """Basic two-convolution residual block, as used by ResNet-18.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.

    Args:
        in_dim: number of input channels.
        out_dim: number of output channels.
        stride: stride of the first 3x3 convolution. The projection shortcut
            applies the same stride so both branches match spatially.
    """

    def __init__(self, in_dim, out_dim, stride):
        super().__init__()
        # Each conv is followed by BatchNorm, whose learned shift makes a conv
        # bias redundant, so the bias is disabled.
        self.conv1 = nn.Conv2D(in_dim, out_dim, 3, stride, 1, bias_attr=False)
        self.bn1 = nn.BatchNorm2D(out_dim)
        self.conv2 = nn.Conv2D(out_dim, out_dim, 3, 1, 1, bias_attr=False)
        self.bn2 = nn.BatchNorm2D(out_dim)
        self.relu = nn.ReLU()
        # The shortcut must be projected whenever the residual branch changes
        # either the spatial size (stride != 1) or the channel count;
        # otherwise the residual addition would see mismatched shapes.
        if stride != 1 or in_dim != out_dim:
            # 1x1 strided projection shortcut (ResNet "option B").
            self.downsample = nn.Sequential(
                nn.Conv2D(in_dim, out_dim, 1, stride, bias_attr=False),
                nn.BatchNorm2D(out_dim),
            )
        else:
            self.downsample = Identity()

    def forward(self, x):
        """Apply the residual block to ``x`` (assumed NCHW, Paddle's default layout)."""
        identity = self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # The shortcut is added before the final activation.
        out = out + identity
        return self.relu(out)
class ResNet18(nn.Layer):
    """ResNet-18 classifier: stem, four stages of two basic blocks, pooled linear head.

    Args:
        in_dim: channel width produced by the stem and fed to the first stage
            (64 in the standard ResNet-18).
        num_classes: number of output logits.
    """

    def __init__(self, in_dim=64, num_classes=1000):
        super().__init__()
        # Running channel count; ``_make_layer`` reads it and advances it.
        self.in_dim = in_dim
        # Stem: 7x7 stride-2 conv (padding 3 halves H and W exactly), then a
        # 3x3 stride-2 max-pool. The stem outputs ``in_dim`` channels so it
        # agrees with the first stage's expected input width.
        self.conv = nn.Conv2D(3, in_dim, 7, 2, 3, bias_attr=False)
        self.norm = nn.BatchNorm2D(in_dim)
        self.relu = nn.ReLU()
        self.max_pooling = nn.MaxPool2D(kernel_size=3, stride=2, padding=1)
        # Stage 1 keeps the resolution; stages 2-4 halve it and double channels.
        self.layer1 = self._make_layer(64, 2, 1)
        self.layer2 = self._make_layer(128, 2, 2)
        self.layer3 = self._make_layer(256, 2, 2)
        self.layer4 = self._make_layer(512, 2, 2)
        self.avgpool = nn.AdaptiveAvgPool2D((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, out_dim, n_blocks, stride):
        """Build one stage of ``n_blocks`` residual blocks.

        Only the first block uses ``stride`` and may change the channel count.
        The remaining blocks keep the shape. ``self.in_dim`` is advanced to
        ``out_dim`` for the next stage.
        """
        layers = [Block(self.in_dim, out_dim, stride)]
        self.in_dim = out_dim
        layers.extend(Block(out_dim, out_dim, 1) for _ in range(n_blocks - 1))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for an image batch ``x`` (assumed N x 3 x H x W)."""
        x = self.conv(x)
        x = self.norm(x)
        x = self.relu(x)
        x = self.max_pooling(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        # Collapse the 1x1 spatial dims, keeping the batch axis.
        x = paddle.flatten(x, 1)
        x = self.fc(x)
        return x
def main():
    """Smoke-test the network: print its structure, then run one random batch."""
    net = ResNet18()
    print(net)
    # Batch of two random 3-channel 32x32 images.
    batch = paddle.randn([2, 3, 32, 32])
    print(net(batch).shape)


if __name__ == "__main__":
    main()