1、使用nn.Sequential()建立模型的三种方式
from collections import OrderedDict

import torch as t
from torch import nn

# Three equivalent ways to build an nn.Sequential model.

# 1) Start from an empty container and register each layer under an explicit name.
net1 = nn.Sequential()
net1.add_module('conv', nn.Conv2d(3, 3, 3))  # Conv2d(in_channels, out_channels, kernel_size)
net1.add_module('batchnorm', nn.BatchNorm2d(3))  # BatchNorm2d(num_features)
net1.add_module('activation_layer', nn.ReLU())

# 2) Pass the layers positionally; they are then addressed by index (0, 1, 2).
net2 = nn.Sequential(
    nn.Conv2d(3, 3, 3),
    nn.BatchNorm2d(3),
    nn.ReLU(),
)

# 3) Pass an OrderedDict of (name, layer) pairs. Note: the keys must be unique.
net3 = nn.Sequential(OrderedDict([
    ('conv1', nn.Conv2d(3, 3, 3)),
    ('bh1', nn.BatchNorm2d(3)),
    ('al', nn.ReLU()),
]))

print('net1', net1)
print('net2', net2)
print('net3', net3)

# Sub-modules can be fetched either by name or by index.
print(net1.conv, net2[0], net3.conv1)
输出:
net1 Sequential(
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (batchnorm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (activation_layer): ReLU()
)
net2 Sequential(
  (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
)
net3 Sequential(
  (conv1): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  (bh1): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (al): ReLU()
)
Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1)) Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
2、使用nn.ModuleList建立模型。
class MyModule(nn.Module):
    """Minimal module contrasting a plain Python list with ``nn.ModuleList``.

    Layers stored in ``self.list`` (an ordinary list) are not registered as
    sub-modules, so they do not show up when the model is printed. Layers
    stored in ``self.module_list`` (an ``nn.ModuleList``) are registered and
    do show up.
    """

    def __init__(self):
        super().__init__()
        # Deliberately a plain list: these layers are NOT registered with the module.
        self.list = [nn.Linear(3, 4), nn.ReLU()]
        # ModuleList registers every contained layer as a sub-module.
        self.module_list = nn.ModuleList([nn.Conv2d(3, 3, 3), nn.ReLU()])

    def forward(self):
        # Intentionally empty: this class only demonstrates module registration.
        pass


model = MyModule()
print(model)
输出(注意:保存在普通 Python list 中的层 self.list 没有被注册为子模块,因此不会出现在打印结果中;只有 nn.ModuleList 中的层会被注册):
MyModule(
  (module_list): ModuleList(
    (0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
  )
)
3、二者结合构造更复杂的网络模型
例如 CsiNet 中的 Decoder 模型
class Decoder(nn.Module):
    """Decoder combining ``nn.ModuleList`` with ``nn.Sequential`` blocks.

    The input is passed through a dequantization layer, a fully connected
    layer with 768 outputs, three residual convolution blocks and a final
    convolution, with a sigmoid after the linear layer and after the final
    convolution. The last step moves axis 1 to the end.

    Args:
        feedback_bits: Total number of feedback bits. The linear layer takes
            ``int(feedback_bits / num_quan_bits)`` input features.
    """

    # Value passed to DequantizationLayer; also the divisor of feedback_bits
    # when sizing the linear layer.
    num_quan_bits = 4

    def __init__(self, feedback_bits):
        super(Decoder, self).__init__()
        self.feedback_bits = feedback_bits
        # NOTE(review): DequantizationLayer and conv3x3 are defined outside this
        # snippet; their exact semantics should be confirmed there.
        self.dequantize = DequantizationLayer(self.num_quan_bits)
        self.multiConvs = nn.ModuleList()
        # 768 == 2 * 24 * 16, matching the view(-1, 2, 24, 16) in forward().
        self.fc = nn.Linear(int(feedback_bits / self.num_quan_bits), 768)
        self.out_cov = conv3x3(2, 2)
        self.sig = nn.Sigmoid()
        # Three identical blocks, each an nn.Sequential registered in the ModuleList.
        for _ in range(3):
            self.multiConvs.append(nn.Sequential(
                conv3x3(2, 8),
                nn.ReLU(),
                conv3x3(8, 16),
                nn.ReLU(),
                conv3x3(16, 2),
                nn.ReLU(),
            ))

    def forward(self, x):
        """Run the decoder on ``x`` and return the result with axis 1 moved last.

        NOTE(review): the result is presumably of shape (N, 24, 16, 2); this
        assumes conv3x3 preserves the spatial size -- confirm against its
        definition.
        """
        out = self.dequantize(x)
        # contiguous().view() is required here; reshape is an alternative.
        out = out.contiguous().view(-1, int(self.feedback_bits / self.num_quan_bits))
        out = self.sig(self.fc(out))
        # contiguous().view() is required here; reshape is an alternative.
        out = out.contiguous().view(-1, 2, 24, 16)
        # Residual connection around each convolution block.
        for block in self.multiConvs:
            out = out + block(out)
        out = self.out_cov(out)
        out = self.sig(out)
        out = out.permute(0, 2, 3, 1)
        return out