PyTorch 模拟 TensorFlow 的 weights 输入方式(适用于元学习)

元学习中需要对参数进行二次求导,因此像 TensorFlow 那样把权重(weights)作为显式输入传给前向函数的写法最为方便。下面的代码用 PyTorch 模拟这种写法,并验证它与普通的模块前向计算结果一致。

from torch.autograd import Variable
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from collections import OrderedDict
from model_meta import common

class g(nn.Module):
    """Tiny conv network whose forward pass can run on externally supplied weights.

    The network is conv -> batch norm -> leaky ReLU, applied twice, followed by
    global average pooling and a linear layer. ``forward`` offers two paths:

    * a module path that uses the layers' own parameters, and
    * a functional path that reads every learnable tensor from a ``weights``
      mapping keyed by the names produced by ``named_parameters()``
      (``"k1.weight"``, ``"bn.bias"``, ...).

    The functional path relies on helpers from ``model_meta.common``, which is
    not part of this file.
    """

    def __init__(self):
        super(g, self).__init__()
        # Registration order determines the order of named_parameters(),
        # so the layers are created in a fixed sequence.
        self.k1 = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=True)
        self.bn = nn.BatchNorm2d(2)
        self.act = nn.LeakyReLU(0.1)
        self.k = nn.Conv2d(in_channels=2, out_channels=2, kernel_size=3, padding=1, bias=True)
        self.ad = nn.AdaptiveAvgPool2d(1)
        self.bn1 = nn.BatchNorm2d(2)
        self.ln = nn.Linear(2, 2)

    def forward(self, z, weights=None, c=False):
        """Run the network on ``z``.

        Args:
            z: Input tensor fed to the first convolution (2 input channels).
            weights: Mapping from parameter name to tensor, used only when
                ``c`` is truthy. When omitted, the module's own
                ``named_parameters()`` are used.
            c: Truthy selects the functional path driven by ``weights``;
                falsy (the default) selects the regular module path.

        Returns:
            The output of the final linear layer.
        """
        if c:
            if weights is None:
                # Allow the functional path without an explicit mapping.
                weights = dict(self.named_parameters())

            # NOTE(review): common.conv2d / batchnorm / linear live in
            # model_meta.common, which is not visible here; presumably they
            # mirror the corresponding layers (including the convolution
            # padding) -- confirm against that module.
            z = common.conv2d(z, weights["k1.weight"], bias=weights["k1.bias"])
            # Batch-norm running statistics are buffers, not parameters, so
            # they are always taken from the module itself.
            z = common.batchnorm(z, weight=weights["bn.weight"], bias=weights["bn.bias"],
                                 running_mean=self.bn.running_mean, running_var=self.bn.running_var,
                                 training=self.training)
            z = F.leaky_relu(z, self.act.negative_slope)
            z = common.conv2d(z, weights["k.weight"], bias=weights["k.bias"])
            z = common.batchnorm(z, weight=weights["bn1.weight"], bias=weights["bn1.bias"],
                                 running_mean=self.bn1.running_mean, running_var=self.bn1.running_var,
                                 training=self.training)
            z = F.leaky_relu(z, self.act.negative_slope)
            # Drop the two trailing size-1 spatial dims left by the pooling.
            z = self.ad(z).squeeze(-1).squeeze(-1)
            z = common.linear(z, weights["ln.weight"], weights["ln.bias"])
        else:
            z = self.k1(z)
            z = self.bn(z)
            z = self.act(z)
            z = self.k(z)
            z = self.bn1(z)
            z = self.act(z)
            # Drop the two trailing size-1 spatial dims left by the pooling.
            z = self.ad(z).squeeze(-1).squeeze(-1)
            z = self.ln(z)
        return z

# Demo: run the same input through the functional path (V1) and the regular
# module path (V2) and print both results for comparison.
net = g().eval()  # eval mode: batch norm uses its running statistics

# Input geometry and value scale.
c = 2
h = 5
w = 5
num = 255.

# Name -> parameter mapping, in registration order.
weights = OrderedDict(net.named_parameters())
print(weights)

# Random input scaled by `num`.
z = torch.rand(1, c, h, w).float() * num

# V1: forward pass driven by the explicit weights mapping.
k = net(z, weights, 1)
print("***********V1************")
print(k)

# V2: forward pass through the module's own layers.
print("************V2***********")
k = net(z, weights, 0)
print(k)

结果:

OrderedDict([('k1.weight', Parameter containing:
tensor([[[[-0.2247, -0.1581, -0.0898],
          [ 0.0360,  0.0034, -0.0012],
          [ 0.1881, -0.2175, -0.1558]],

         [[ 0.2345, -0.2052,  0.2291],
          [ 0.1458, -0.0778, -0.0761],
          [ 0.1458,  0.1497,  0.1909]]],


        [[[-0.1647,  0.0314, -0.2093],
          [-0.0598,  0.0189, -0.2058],
          [-0.2004,  0.0625,  0.1661]],

         [[-0.1550,  0.2228,  0.2277],
          [-0.1925,  0.1914, -0.1848],
          [-0.0585,  0.2001,  0.1779]]]], requires_grad=True)), ('k1.bias', Parameter containing:
tensor([ 0.1321, -0.2026], requires_grad=True)), ('bn.weight', Parameter containing:
tensor([0.6140, 0.3376], requires_grad=True)), ('bn.bias', Parameter containing:
tensor([0., 0.], requires_grad=True)), ('k.weight', Parameter containing:
tensor([[[[ 0.1454, -0.1201, -0.0085],
          [ 0.0584,  0.1009,  0.1226],
          [-0.1576,  0.1127, -0.0389]],

         [[ 0.0483,  0.0248,  0.0990],
          [-0.2266, -0.1486, -0.0324],
          [-0.0946,  0.0063,  0.1903]]],


        [[[ 0.0238,  0.0458, -0.1987],
          [-0.1096, -0.1962, -0.1864],
          [-0.1547, -0.0741,  0.1740]],

         [[-0.0820,  0.2186,  0.0900],
          [-0.0165,  0.0776,  0.0946],
          [ 0.0113, -0.2241,  0.2184]]]], requires_grad=True)), ('k.bias', Parameter containing:
tensor([-0.1629, -0.0589], requires_grad=True)), ('bn1.weight', Parameter containing:
tensor([0.5731, 0.4600], requires_grad=True)), ('bn1.bias', Parameter containing:
tensor([0., 0.], requires_grad=True)), ('ln.weight', Parameter containing:
tensor([[ 0.3224, -0.5423],
        [ 0.3757,  0.0196]], requires_grad=True)), ('ln.bias', Parameter containing:
tensor([-0.0116, -0.4640], requires_grad=True))])
***********V1************
tensor([[0.1756, 0.0552]], grad_fn=<AddmmBackward>)
************V2***********
tensor([[0.1756, 0.0552]], grad_fn=<AddmmBackward>)

上一篇:我的MYSQL学习心得(十四) 备份和恢复


下一篇:20191317王鹏宇第四章学习笔记