nn.Module 函数详解
nn.Module是所有网络模型结构的基类,无论是pytorch自带的模型,还是要自定义模型,都需要继承这个类。这个模块包含了很多子模块,如下所示,_parameters存放的是模型的参数,_buffers也存放的是模型的参数,但是是那些不需要更新的参数。带hook的都是钩子函数,详见钩子函数部分。
self._parameters = OrderedDict() self._buffers = OrderedDict() self._non_persistent_buffers_set = set() self._backward_hooks = OrderedDict() self._is_full_backward_hook = None self._forward_hooks = OrderedDict() self._forward_pre_hooks = OrderedDict() self._state_dict_hooks = OrderedDict() self._load_state_dict_pre_hooks = OrderedDict() self._modules = OrderedDict()此外,每一个模块还内置了一些常用的方法来帮助访问和操作网络。
load_state_dict() #加载模型权重参数 parameters() #读取所有参数 named_parameters() #读取参数名称和参数 buffers() #读取self.named_buffers中的参数 named_buffers() #读取self.named_buffers中的参数名称和参数 children() #读取模型中,所有的子模型 named_children() #读取子模型名称和子模型 requires_grad_() #设置模型是否开启梯度反向传播Parameter类
Parameter是Tensor子类,所以继承了Tensor类的属性。例如data和grad属性,可以根据data来访问参数数值,用grad来访问参数梯度。
weight_0 = nn.Parameters(torch.randn(10,10)) print(weight_0.data) print(weight_0.grad)定义变量的时候,nn.Parameter会被自动加入到参数列表中去
class MyModel(nn.Module): def __init__(self): super(MyModel,self).__init__() self.weight1 = nn.Parameter(torch.randn(10,10)) self.weight2 = torch.randn(10,10) def forward(self,x): pass model = MyModel() for name,param in model.named_parameters(): print(name) output: weight1ParameterList
接定义成Parameter类外,还可以使用ParameterList和ParameterDict分别定义参数的列表和字典。ParameterList接收一个Parameter实例的列表作为输入然后得到一个参数列表,使用的时候可以用索引来访问某个参数,另外也可以使用append和extend在列表后面新增参数。
params = nn.ParameterList( [nn.Parameter(torch.randn(10,10)) for i in range(5)] ) params.append(nn.Parameter(torch.randn(3,3)))ParameterDict
可以像添加字典数据那样添加参数
params = nn.ParameterDict({ 'linear1':nn.Parameter(torch.randn(10,5)), 'linear2':nn.Parameter(torch.randn(5,2)) })模型构建
使用Sequential构建模型
# 写法一 net = nn.Sequential( nn.Linear(num_inputs, 1) # 此处还可以传入其他层 ) # 写法二 net = nn.Sequential() net.add_module('linear', nn.Linear(num_inputs, 1)) # net.add_module ...... # 写法三 from collections import OrderedDict net = nn.Sequential(OrderedDict([ ('linear', nn.Linear(num_inputs, 1)) # ...... ])) print(net)自定义模型
- 无参数模型
下面是一个展开操作,比如将2维图像展开成一维
class Flatten(nn.Module): def __init__(self): super(Flatten,self).__init__() def forward(self,input): return input.view(input.size(0),-1)- 有参数模型
自定义一个Linear层
class MLinear(nn.Module): def __init__(self,input,output): super(MyLinear,self).__init__() self.w = nn.Parameter(torch.randn(input,output)) self.b = nn.Parameter(torch.randn(output)) def foward(self,x): x = self.w @ x + self.b return x- 组合模型
ModuleList & ModuleDict
ModuleList 和 ModuleDict都是继承与nn.Module, 与Seuqential不同的是,ModuleList 和 ModuleDict没有自带forward方法,所以只能作为一个模块和其他自定义方法进行组合。下面是使用示例:
class MyModuleList(nn.Module): def __init__(self): super(MyModuleList, self).__init__() self.linears = nn.ModuleList( [nn.Linear(10, 10) for i in range(3)] ) def forward(self, x): for linear in self.linears: x = linear(x) return x class MyModuleDict(nn.Module): def __init__(self): super(MyModuleDict, self).__init__() self.linears = nn.ModuleDict({ "linear1":nn.Linear(10,10), "linear2":nn.Linear(10,10) }) def forward(self, x): x = self.linears["linear1"](x) x = self.linears["linear2"](x) return x