BN剪枝实践

1.github代码实践

源代码是lua脚本语言,下载th之后运行

th main.lua -netType vgg -save vgg_cifar10/ -S 0.0001,报错:

BN剪枝实践

 试看看安装lua:

报错了,参考这篇文章:

ubuntu18.04安装lua的步骤 以及 出现的问题_weixin_41355132的博客-****博客

问题解决,安装成功:

BN剪枝实践

情况并没有好转,出现相同的报错(想了一下发现可能前面安装过lua了,好像跟torch啥的包有关):

BN剪枝实践

可能是这个东西没安装:

BN剪枝实践

进到这个里面的INSTALL,发现可能lua版本的torch,cuda,cudnnv4之类的东西没有装,在装torch的时候发现确实下载了Lua相关的组件,但是不知道cuda和cudnn要不要重新安装,cuda看起来安装过程没有什么特别的,cudnn的下载好像需要登录nvidia比较烦人

安装torch:

BN剪枝实践

torch安装好像失败了:

BN剪枝实践

发现有第三方代码用pytorch实现,准备转战:

BN剪枝实践

https://github.com/foolwood/pytorch-slimming

先把torch卸载了:

BN剪枝实践

原Torch实现代码在这里:

BN剪枝实践

后来下载的另一个git在这里:

BN剪枝实践

pytorch实现的剪枝:

BN剪枝实践

 运行main.py,成功下载CIFAR10后报错

BN剪枝实践

是打印结果的这部分有错误(最后一行):(解决,把loss.data[0]的[0]删掉就可以了)

BN剪枝实践

但是突然调试不了了(换了电脑试还是不行,可能试代码的问题,没法调试,一调试就卡死,然后),先注释掉运行之后,显卡转起来了,然后报错:

BN剪枝实践

 发现试test部分出了问题:

test_loss += F.cross_entropy(output, target, size_average=False).data # sum up batch loss   .data[0]

这里把.data[0]给改成了.data就跑起来了 ,到最大epoch的0.5到0.75的时候调整学习率

权重和best权重保存在这里:

BN剪枝实践

开始训练

batch-size=100,每100batch报告一次,10000张图片输出一次,总共50000张图片,epoch=160

GPU:GTX TITAN,12G显存

显存占用:7566MiB / 12187MiB

内存占用:下图

BN剪枝实践

BN剪枝实践

BN剪枝实践

160epoch结果:

BN剪枝实践

精度大概是93.5%

运行完了之后:

BN剪枝实践

BN剪枝实践

运行的时候显存占用约:7560-1358=6202  MiB(batch-size=100)(CIFAR-10)

 查看参数量:

使用第三方工具torchstat

from torchstat import stat
from vgg import vgg


model = vgg()
stat(model,(3,32,32))

直接算了原来vgg的参数量和运算量 :

BN剪枝实践

FLOPS:注意全大写,是floating point operations per second的缩写,意指每秒浮点运算次数,理解为计算速度。是一个衡量硬件性能的指标。

FLOPs:注意s小写,是floating point operations的缩写(s表复数),意指浮点运算数,理解为计算量。可以用来衡量算法/模型的复杂度。
MACCs:是(multiply-accumulate operations),也叫MAdds,指乘-加操作(点积运算),理解为计算量, 大约是 FLOPs 的一半。

BN剪枝实践

开始稀疏化训练:

s=0.0001

BN剪枝实践

best pred=0.9376000165939331

剪枝

BN剪枝实践

BN剪枝实践

BN剪枝实践 stat参数量:

BN剪枝实践

压缩率:2.88/3.59 = 80.22284122562674%

Flops压缩率:211.17/399.17 = 52.9022722148458%

测验参数量代码:

from torchstat import stat
from vgg import vgg
import torch



checkpoint = torch.load('pruned.pth.tar')   # args.refine存储的是稀疏化训练之后的权重文件
model = vgg(cfg=checkpoint['cfg'])
stat(model,(3,32,32))

Fine-turn:

精度:93.3%

几个小问题:

上一篇:stat和lstat函数


下一篇:K8s ❉ 报错cannot stat ‘/etc/kubernetes/admin.conf’