动态计算图
1. 计算图
-
计算图是用来描述运算的有向无环图
-
计算图有两个主要元素:结点(Node)和边(Edge)
-
结点表示数据,如向量,矩阵,张量。边表示运算,如加减乘除卷积等
-
用计算图表示:
y = ( x + w ) ∗ ( w + 1 ) y=(x+w)*(w+1) y=(x+w)∗(w+1)
采用运算法的优势是令梯度的计算更加方便,y对w求导的过程如下:
y对w求导一共包含两项内容,分别是y对a求导和y对b求导。
2.PyTorch的动态图
根据计算图搭建方式的不同,可将计算图分为动态图和静态图。
动态图
- 运算与搭建同时进行,典型代表是pytorch
- 特点:灵活、易调节
先创建原始数据,之后执行第一个乘法操作,然后再执行另一个乘法操作,之后执行加法操作,接着执行一个激活函数,最后计算一个loss,有了loss之后,执行梯度反向传播。
静态图
- 先搭建图,后运算,典型代表TensorFlow
- 特点:高效、不灵活
tensor存进去之后,无法改变tensor的流动方向
3.自动微分变量
PyTorch使用自动微分变量实现动态计算图。在PyTorch0.4中自动微分变量已经与张量完全合并。即,任意一个张量都是一个自动微分变量。
采用自动微分计算时,系统自动构建计算图,即,存储计算路径。可通过访问自动微分变量的grad_fn来获得计算图中的上一个节点,可知哪个运算导致此自动微分变量出现。每个节点的grad_fn就是计算图中的箭头。可使用grad_fn回溯来重构整个计算图。
最后进行反向传播算法时,需要计算计算图中每个变量节点的梯度值(grandient,即该变量需要被更新的增量)。我们只需要调用.backward()函数即可算出所有变量的梯度信息,并将叶节点的导数值存储在.grad中。