李宏毅机器学习课程梳理【十六】:Network Compression(压缩深度神经网络)

文章目录

摘要

大型机器学习的模型需要一些方法来压缩。

1 Network Compression

在一些移动设备上,如智能手机、智能手表、机器人、无人机飞行器等设备,没有足够的存储空间去使用训练好的模型,因为深度神经网络可能过于庞大。

1.1 Network Pruning

现在训练出来的模型是参数冗余的,在测试时发现有些参数甚至有些神经元根本没有被用到,值始终为0或非常接近于0,所以可以在网络中把它们移除。Network Pruning(网络剪枝)的步骤:对一个训练好的模型,评估参数和神经元的重要性,参数的评估方法为计算L1、L2的值、神经元的评估方法为观察给定的数据集中神经元的输出,将重要性评估结果排序,移除不重要的参数和神经元,移除后模型的表现会变差一点。Fine-tune是用原训练集数据训练剪枝后的模型,得到的模型能克服剪枝带来的变差影响。
李宏毅机器学习课程梳理【十六】:Network Compression(压缩深度神经网络)
恢复影响以后,判断网络是否足够小,如果不满意现在网络的大小,则重复以上步骤,继续剪枝,直到满足条件。

1.2 Why Pruning?

为什么要训练一个大的网络,再剪枝呢?何不直接训练一个小的网络?研究表明,小的网络不容易用梯度找到全局最低点,而只要网络足够大,一定能移动到全局最低点。有学者提出Lottery Ticket Hypothesis大乐透理论,指出大的网络容易训练,是因为大的网络存在更多种小网络,也就存在更大的训练成功的可能性。众多小网络中,有一个能够训练成功,整个大网络就能训练成功。研究过程如图2所示。
李宏毅机器学习课程梳理【十六】:Network Compression(压缩深度神经网络)
训练大网络时随机初始化参数,用红色表示这些参数,训练成功后用紫色表示这些参数,然后移除掉未使用的神经元和参数,得到小网络。如果试图重新随机初始化小网络的参数,则最终训练失败。如果使用红色的对应参数,则能训练成功。于是给出大网络容易训练现象的一种解释。

Rethinking the Value of Network Pruning在CIFAR-10和ImageNet数据集上训练不同类型的网络,做网络剪枝,与直接随机初始化参数训练小网络对比,发现直接训练小网络的准确性还略高,结论是小网络可以直接训练起来。这一结论与Lottery Ticket Hypothesis的结论相反,感兴趣的可以去关注。

1.3 Weight Pruning V.S Neuron Pruning

在实做时,移除参数会形成不规则网络,导致无法用GPU加速矩阵计算,也无法在一些平台中实施,如图3所示。
李宏毅机器学习课程梳理【十六】:Network Compression(压缩深度神经网络)
因为不规则的网络受很多计算限制,剪枝后的小网络训练起来会比较慢。
但是Weight Pruning的效果很强,比如有实验移除95%的参数,准确度仅降低1%~2%
Neuron Pruning将无用的神经元和它前后连接的参数都移除。对比起来,Neuron Pruning训练得更快,实做更简单。

2 Knowledge Distillation

Knowledge Distillation的工作是训练一个大网络Teacher Net,将这个大网络的输入与输出作为小网络Student Net的训练集数据,让小网络去学习大网络的运作,并不是让小网络直接学有标签的原始训练集数据。也就是训练小网络,使小网络的输入与大网络的输入相同时,大小网络的输出的交叉熵最小。
这样,小网络可以学习到“1”与“7”与“9”很像等知识,比只学习训练集数据的效果好。

2.1 Application of Ensemble

集成学习的一般结构是先产生一组“个体学习器”,再用某种策略将它们结合起来。而Knowledge Distillation可作为此种策略将它们结合起来。训练N个集成学习的网络,将N个网络的输出求平均作为最终结果,让Student Net的输入与N个网络相同、输出与平均结果计算交叉熵、使交叉熵最小化。一个Student Net就可以达到N个网络的效果。

2.2 Temperature

多分类问题通过Softmax层得到对应分类的概率,在Knowledge Distillation实做中,输出的表达式增加参数T,用以实现小网络能准确地学到大网络的输出,学到比训练集数据更多的信息。
李宏毅机器学习课程梳理【十六】:Network Compression(压缩深度神经网络)

3 Parameter Quantization

参数量子化也是压缩网络的一个方法,下面介绍它有几种实现方法。

  1. 每一个值都使用更少的比特去表示
  2. Weight clustering,将网络中的参数分群,存储时只存cluster的id号和cluster的值,将网络压缩成一个表格。如果分成四群,则只占2比特(00,01,10,11四种)。此种方法会损失一些精度,各个cluster的值都是平均值。Weight clustering的示意图如图5所示。
  3. 在2的基础上进一步压缩,使用赫夫曼编码的思想,比较常出现的cluster用比较短的比特来表示、不常出现的cluster用比较长的比特表示。
上一篇:compression-webpack-plugin 开启gzip vue


下一篇:李宏毅-Network Compression课程笔记