李沐_pytorch补充

1、注册带有参数的层时候就要使用nn.Parameter()

self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

self.v=nn.Parameter()
self.v=nn.Linear()
要将左侧的v注册成为参数,右侧就需要进行nn.Parameter这个操作,第二个没有是因为Linear本身就是封装好了的。
这也就意味着,所有有学习参数的层,必须在__init__内部封装好。
2、介绍一下**nn._modules()**层的容器,可以用来复写sequential

class MySequential(nn.Module):
    def __init__(self,*args):
        super().__init__()
        for block in args:
            #._modules是包装层的容器,按顺序输入的插入层
            self._modules[block]=block
    def forward(self,x):
        for block in self._modules.values():
            x=block(x)
        return x
net = MySequential(nn.Linear(20,256),nn.ReLU(),nn.Linear(256,10))
net(x)

3、在前项传播中可以写代码,不仅仅包括一些层的传导,还可以书写代码。

class FixedHiddenMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_weight = torch.rand((20,20),requires_grad=False)
        self.linear = nn.Linear(20,20)
    def forward(self,x):
        x=self.linear(x)
        x=F.relu(torch.mm(x,self.rand_weight)+1)
        x=self.linear(x)
        while x.abs().sum()>1:
            x/=2
        return x.sum()
net=FixedHiddenMLP()
net(x)

4、参数的访问
函数如图所示:

import torch 
from torch import nn
net = nn.Sequential(nn.Linear(4,8),nn.ReLU(),nn.Linear(8,1))
x=torch.rand(size=(2,4))
net(x)
print(net[2].state_dict())

得到nn.Linear(8,1)的所有参数

print(net[2].bias)

访问nn.Linear(8,1)特定的参数bias

print(net[2].bias.data)

得到参数的数值
5、net.named_parameters()可以得到网络的名字和参数,还可以使用net[i]取得特定块的名字和参数

net.named_parameters()中param是len为2的tuple
param[0]是name,fc1.weight、fc1.bias等
param[1]是fc1.weight、fc1.bias等对应的值
23就是另一组
56又是另一组

6、对指定层进行参数初始化

def init_normal(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight,mean=0,std=0.1)
        nn.init.zeros_(m.bias)
net.apply(init_normal)
net[0].weight.data[0],net[0].bias.data[0]
def init_constant(m):
    if type(m) == nn.Linear:
        nn.init.constant_(m.weight,1)
        nn.init.zeros_(m.bias)
net.apply(init_constant)
net[0].weight.data[0],net[0].bias.data[0]

7、直接对指定的参数进行操作

net[0].weight.data[:] +=1
上一篇:【程序员必会十大算法】之Prim算法


下一篇:XML实现动物园动物添加功能