训练
模型由两个网络组成:
Embedding 层(EmbeddingOmniglot)和 GNN 度量层(MetricNN)。
EmbeddingOmniglot(omniglot)的网络结构如下:
EmbeddingOmniglot(
(conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
(bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
(bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(fc_last): Linear(in_features=576, out_features=64, bias=False)
(bn_last): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
MetricNN 的网络结构如下:
MetricNN(
(gnn_obj): GNN_nl_omniglot(
(layer_w0): Wcompute(
(conv2d_1): Conv2d(84, 168, kernel_size=(1, 1), stride=(1, 1))
(bn_1): BatchNorm2d(168, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_2): Conv2d(168, 126, kernel_size=(1, 1), stride=(1, 1))
(bn_2): BatchNorm2d(126, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_3): Conv2d(126, 84, kernel_size=(1, 1), stride=(1, 1))
(bn_3): BatchNorm2d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_4): Conv2d(84, 84, kernel_size=(1, 1), stride=(1, 1))
(bn_4): BatchNorm2d(84, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_last): Conv2d(84, 1, kernel_size=(1, 1), stride=(1, 1))
)
(layer_l0): Gconv(
(fc): Linear(in_features=168, out_features=48, bias=True)
(bn): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(layer_w1): Wcompute(
(conv2d_1): Conv2d(132, 264, kernel_size=(1, 1), stride=(1, 1))
(bn_1): BatchNorm2d(264, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_2): Conv2d(264, 198, kernel_size=(1, 1), stride=(1, 1))
(bn_2): BatchNorm2d(198, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_3): Conv2d(198, 132, kernel_size=(1, 1), stride=(1, 1))
(bn_3): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_4): Conv2d(132, 132, kernel_size=(1, 1), stride=(1, 1))
(bn_4): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_last): Conv2d(132, 1, kernel_size=(1, 1), stride=(1, 1))
)
(layer_l1): Gconv(
(fc): Linear(in_features=264, out_features=48, bias=True)
(bn): BatchNorm1d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
(w_comp_last): Wcompute(
(conv2d_1): Conv2d(180, 264, kernel_size=(1, 1), stride=(1, 1))
(bn_1): BatchNorm2d(264, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(dropout): Dropout(p=0.3, inplace=False)
(conv2d_2): Conv2d(264, 198, kernel_size=(1, 1), stride=(1, 1))
(bn_2): BatchNorm2d(198, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_3): Conv2d(198, 132, kernel_size=(1, 1), stride=(1, 1))
(bn_3): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_4): Conv2d(132, 132, kernel_size=(1, 1), stride=(1, 1))
(bn_4): BatchNorm2d(132, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2d_last): Conv2d(132, 1, kernel_size=(1, 1), stride=(1, 1))
)
(layer_last): Gconv(
(fc): Linear(in_features=360, out_features=20, bias=True)
(bn): BatchNorm1d(20, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
)
)