pytorch(11)模型创建步骤与nn.Module

模型创建与nn.Module

  1. 网络模型创建步骤

    nn.Module
graph LR
模型 --> 模型创建
模型创建 --> 构建网络层
构建网络层 --> id[卷积层,池化层,激活函数层]
模型 --> 权值初始化
权值初始化 --> id1[Xavier,Kaiming,均匀分布,正太分布]
模型创建 --> 拼接网络层
拼接网络层 --> id2[LeNet,AlexNet,ResNet]

LeNet

Conv1 --> pool1 --> Conv2 --> pool2 --> fc1-->fc2 --> fc3

模型创建步骤:

graph LR
A[模型构建两要素] --> B[构建子模块]
B[构建子模块] --> C["__init__()"]
A[模型构建两要素] --> E[拼接子模块]
E[拼接子模块] --> F["forwar()"]
  1. nn.Module属性

    在模型模块中,有一个非常的重要概念是nn.Module,所有的模型、网络层都会继承nn.Module类的
graph LR
A["torch.nn"] --> B["nn.Parameter"]
B["nn.Parameter"] --> D["张量子类,表示可学习参数,如weights,bias"]
A["torch.nn"] --> E["nn.Module"]
E["nn.Module"] --> F["所有网络层基类,管理网络属性"]
A["torch.nn"] --> G["nn.functional"]
G["nn.functional"] --> H["函数具体实现,如卷积,池化,激活函数等"]
A["torch.nn"] --> I["nn.init"]
I["nn.init"] --> K["参数初始化方法"]

nn.Module八个重要的属性,用于管理整个模型

  • parameters:存储管理nn.Parameter类,例如权值、偏置等参数
  • modules:存储管理nn.Module类。例如LeNet,它会构建它的子模块,卷积层、池化层。LeNet的modules示例,就会存储它的卷积层、池化层。
  • buffers:存储管理缓冲属性,如BN层中的running_mean
  • ***_hooks:存储管理钩子函数
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()

net = LeNet(classes=2)

  1. 此时调用lenet.py中的class LeNet(nn.Module)的__init__,继承super(LeNet,self).__init__方法,这是一个module的方法,初始化parameters,buffers等参数。因此首先将LeNet作为一个module初始化完成。
  2. self.conv1 = nn.Conv2d(3, 6, 5)。此时调用conv中的Conv2d类,继承的是_ConvNde类,使用其__init__方法,初始化参数(通道数、输出通道数、padding等),super的初始化方法,即_ConvNd类,此类继承Module类,因此也是一个module类,进行初始化。进入它的super方法,最终也是Module.
  3. self.conv1 = nn.Conv2d(3, 6, 5)。再调用赋值的方法时并不是直接赋值,而是在判断是否是参数,如果是参数,那就赋值给parameter。如果不是参数,如果是module,那就把value的值赋给modules。如果是buffers,那就把value的值赋给buffers.
  4. self.fc1 = nn.Linear(1655, 120)。调用Linear的类,这个类也是继承的Module类,super(Linear,self)类,也是一个Module的类。Linear的方法中的赋值都用到了module.py中的setattr方法,用来判断是参数还是模型

nn.Module

一个module可以包含多个子module

一个module相当于一个运算,必须实现forward()函数

每个modeule都有8个字典管理它的属性

pytorch

数据模块,将数据转换为张量形式输入模型,在深度学习模型中,对输入的张量进行复杂的数学运算,进行分类、分割、目标检测的输出。

上一篇:oracle数据库如何打印九九乘法表


下一篇:测试那些事儿—Linux搭建环境基础步骤