class CONCAT_CNN(nn.Module):
    """Single-layer CNN: one 3x3 convolution (1 -> 64 channels) followed by ReLU.

    With kernel_size=3, stride=1 and padding=1 the convolution preserves the
    spatial size (H, W) of its input.
    """

    def __init__(self):
        super().__init__()
        # 1 input channel -> 64 feature maps; padding=1 keeps H and W unchanged.
        self.conv1_1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Apply ``conv1_1`` and then ReLU.

        Args:
            x: input tensor with a single channel; presumably shaped
                (N, 1, H, W) -- TODO confirm the batch layout used by callers.

        Returns:
            Non-negative activations with 64 channels and the same H, W as ``x``.
        """
        return F.relu(self.conv1_1(x))

    def initialize(self):
        """Re-initialize every ``nn.Conv2d`` found in this module tree.

        Weights get Kaiming (He) normal initialization, matched to the ReLU
        that follows the convolution. Biases are zeroed when present.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # nn.init functions already run under torch.no_grad(), so the
                # Parameter is passed directly rather than through `.data`.
                # nonlinearity="relu" gives the same gain (sqrt(2)) as the
                # default leaky_relu with a=0, but states the intent.
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                # Conv2d(..., bias=False) stores bias as None; zeroing it
                # unconditionally would raise AttributeError.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
# Smoke check: build the model, apply the custom initialization, and report
# the architecture together with its total parameter count.
net = CONCAT_CNN()
net.initialize()
print(net)
print(
    "Total number of parameters in the network: "
    f"{sum(p.numel() for p in net.parameters())}"
)
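# Weight-initialization cheat sheet (each scheme has a normal and a uniform variant).
# For ReLU activations: Kaiming (He) normal / Kaiming uniform
#   nn.init.kaiming_normal_(m.weight.data)
#   nn.init.kaiming_uniform_(m.weight.data)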
# Plain uniform / normal distributions:
#   nn.init.uniform_(m.weight.data)           # uniform: U(a=0, b=1) by default
#   nn.init.normal_(m.weight.data, std=0.01)  # normal: mean=0, std=0.01 (default std is 1)
# For Sigmoid / Tanh activations: Xavier (Glorot) normal / Xavier uniform
#   nn.init.xavier_normal_(m.weight.data)
#   nn.init.xavier_uniform_(m.weight.data)