PyTorch 模型剪枝实例教程一、非结构化剪枝

2022-09-24 09:18:40

目前大部分最先进的(SOTA)深度学习技术虽然效果好,但由于其模型参数量和计算量过高,难以用于实际部署。而众所周知,生物神经网络使用高效的稀疏连接(生物大脑神经网络balabala啥的都是稀疏连接的),考虑到这一点,为了减少内存、容量和硬件消耗,同时又不牺牲模型预测的精度,在设备上部署轻量级模型,并通过私有的设备上计算以保证隐私,通过减少参数数量来压缩模型的最佳技术非常重要

稀疏神经网络在预测精度方面可以达到密集神经网络的水平,但由于模型参数量小,理论上来讲推理速度也会快很多。而模型剪枝是一种将密集神经网络训练成稀疏神经网络的方法。

本文将通过学习Torch官方示例教程,介绍如何通过一个简单的实例教程来进行模型剪枝,实践深度学习模型压缩加速。

相关链接

深度学习模型压缩与加速技术(一):参数剪枝

PyTorch模型剪枝实例教程一、非结构化剪枝

PyTorch模型剪枝实例教程二、结构化剪枝

PyTorch模型剪枝实例教程三、多参数与全局剪枝

1.导包&定义一个简单的网络

import torchfrom torchimport nnimport torch.nn.utils.pruneas pruneimport torch.nn.functionalas F

device= torch.device("cuda"if torch.cuda.is_available()else"cpu")'''搭建类LeNet网络'''classLeNet(nn.Module):def__init__(self):super(LeNet, self).__init__()# 单通道图像输入,5×5核尺寸
        self.conv1= nn.Conv2d(1,3,5)
        self.conv2= nn.Conv2d(3,16,5)
        self.fc1= nn.Linear(16*5*5,120)  
        self.fc2= nn.Linear(120,84)
        self.fc3= nn.Linear(84,10)defforward(self, x):
        x= F.max_pool2d(F.relu(self.conv1(x)),(2,2))
        x= F.max_pool2d(F.relu(self.conv2(x)),2)
        x= x.view(-1,int(x.nelement()/ x.shape[0]))
        x= F.relu(self.fc1(x))
        x= F.relu(self.fc2(x))
        x= self.fc3(x)return x

2.获取网络需要剪枝的模块

model= LeNet().to(device=device)
module= model.conv1print(list(module.named_parameters()))# 6×5×5的weight + 6×1的bias 的参数量print("缓冲区数据",list(module.buffers()))# 缓冲区暂时没有数据

输出:

[('weight', Parameter containing:
tensor([[[[ 0.1473,  0.1251,  0.0492, -0.1375, -0.0781],
          [ 0.0446, -0.1328,  0.0227,  0.0141, -0.1751],
          [ 0.0253,  0.0313,  0.0391,  0.1607, -0.0716],
          [-0.1125, -0.1641,  0.1691,  0.1583,  0.0449],
          [-0.0094, -0.1916,  0.1701,  0.0704,  0.0407]]],


        [[[-0.1945,  0.0709,  0.1071,  0.0038, -0.0686],
          [ 0.0187,  0.0710, -0.0955, -0.0778,  0.1927],
          [ 0.1643,  0.0791,  0.1235,  0.0241, -0.0021],
          [-0.1124,  0.0246, -0.0349, -0.1561,  0.0178],
          [-0.1779,  0.1216,  0.1086, -0.1837,  0.1789]]],


        [[[-0.0051, -0.1969, -0.0155,  0.1890,  0.1977],
          [-0.0654,  0.1219,  0.0849, -0.1937, -0.0933],
          [-0.0409,  0.1344,  0.1688,  0.1917, -0.1727],
          [ 0.1380, -0.1413, -0.1483, -0.0711, -0.0648],
          [-0.1571,  0.0570,  0.1783, -0.0786,  0.1367]]]], requires_grad=True)), ('bias', Parameter containing:
tensor([ 0.0346, -0.1446,  0.0633], requires_grad=True))]
缓冲区数据 []

可以发现,通过.named_parameters()方法,可以得到conv1模块的参数和偏置数据,同时缓冲区Buffer数据为空

这里顺便说下,关于PyTorch中有关Buffer和Paramater的区别

一般来说,Torch模型中需要保存下来的参数包括两种:

  • 一种是反向传播需要被optimizer更新的,称之为 parameter
  • 一种是反向传播不需要被optimizer更新,称之为 buffer

第一种参数我们可以通过model.parameters() 返回;第二种参数我们可以通过model.buffers() 返回。因为我们的模型保存的是state_dict 返回的OrderDict,所以这两种参数不仅要满足是否需要被更新的要求,还需要被保存到OrderDict

3.模块剪枝(核心)

剪枝一个模块,需要三步:

  • step1.在torch.nn.utils.prune中选定一个剪枝方案,或者自定义(通过子类BasePruningMethod)
  • step2.指定需要剪枝的模块和对应的名称
  • step3.输入对应函数需要的参数

这里示例一个非结构化剪枝方法,random_unstructured(),选定conv1模块,剪枝比例为30%

# 这里,选用方案为随机非结构化剪枝module(conv1)中weight的参数,比例为30%
prune.random_unstructured(module,name='weight',amount=0.3)

'修剪的作用是将权重从参数中移除,并用一个名为weight_orig的新参数替换它(即在初始参数名称后面添加“_orig”)。weight_trans存储了张量的未剪枝的版本。bias没有被修剪,所以它会保持不变。我们看看现在module的weight变成啥样了

print(list(module.named_parameters()))

输出:

tensor([ 0.1363, -0.0978,  0.1246], requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[-0.1949, -0.1004,  0.1231, -0.1788, -0.0385],
          [-0.1404,  0.1485,  0.1492, -0.0044,  0.1715],
          [ 0.0118, -0.0254,  0.0238,  0.1694, -0.1564],
          [ 0.1296,  0.0766,  0.1456,  0.0181,  0.1586],
          [-0.0531,  0.1709,  0.1242,  0.0671,  0.1864]]],


        [[[-0.0218,  0.0216, -0.1337,  0.0226, -0.0229],
          [ 0.1921,  0.0834, -0.1653,  0.1647,  0.0668],
          [-0.1422, -0.1798,  0.0899,  0.0038,  0.1207],
          [-0.0348, -0.1031, -0.1191,  0.0156, -0.1276],
          [ 0.0353,  0.0265, -0.1072,  0.0520,  0.1278]]],


        [[[-0.0569, -0.0463, -0.0963,  0.0876, -0.1442],
          [ 0.0623,  0.1549, -0.1358,  0.0810,  0.0437],
          [-0.1940, -0.0122,  0.1128,  0.1723, -0.1043],
          [-0.0370,  0.0330, -0.0919, -0.1447,  0.0477],
          [ 0.1211, -0.1251,  0.1661,  0.1127, -0.1026]]]], requires_grad=True))]

可以发现,原始的weight被weight_orig代替,bias保持不变

由上述选择的剪枝方案生成的剪枝掩码被保存为一个名为weight_mask的模块缓冲区(即在初始参数名称后面添加“_mask”)。

print(list(module.buffers()))

输出

[tensor([[[[1., 0., 1., 1., 0.],
          [1., 1., 1., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.],
          [1., 0., 1., 1., 0.]]],

        [[[1., 1., 1., 0., 1.],
          [1., 0., 1., 1., 1.],
          [1., 0., 0., 1., 1.],
          [0., 1., 1., 1., 1.],
          [1., 1., 0., 0., 1.]]],

        [[[1., 1., 0., 0., 1.],
          [1., 1., 1., 1., 1.],
          [0., 1., 1., 1., 1.],
          [0., 0., 1., 1., 0.],
          [0., 1., 1., 1., 1.]]]])]

可以发现,buffers多出来6×5×5的数据,其中mask中的0代表被剪枝,1代表未被剪枝。实际上就是mask与原始参数进行组合,然后保存在weight中,要注意此时它不再是模型的参数,而只是一个属性。

print(module.weight)

输出

tensor([[[[-3.5525e-02,2.8088e-02,1.0221e-01,6.5053e-02,-1.2882e-01],[1.8767e-01,-1.5429e-01,8.7599e-02,-7.1018e-02,0.0000e+00],[1.7288e-01,0.0000e+00,-1.5061e-01,1.5144e-01,-1.5307e-01],[-7.1318e-03,8.8781e-05,1.7603e-01,3.9326e-02,-1.5911e-01],[-6.2226e-02,-6.4120e-02,-8.2244e-02,-1.1819e-01,-1.7782e-01]]],[[[-0.0000e+00,0.0000e+00,8.4743e-02,0.0000e+00,1.6629e-01],[-0.0000e+00,1.3452e-02,1.4956e-01,-2.5982e-02,1.0650e-01],[1.0035e-01,-0.0000e+00,5.2241e-02,-0.0000e+00,1.9607e-01],[-0.0000e+00,1.8732e-01,6.5181e-02,-0.0000e+00,0.0000e+00],[-1.5288e-01,-1.1800e-01,2.0679e-02,0.0000e+00,-1.7894e-01]]],[[[-4.8152e-02,-0.0000e+00,0.0000e+00,-0.0000e+00,6.8126e-03],[-1.3027e-01,0.0000e+00,-1.6212e-01,0.0000e+00,-8.0152e-02],[0.0000e+00,-6.6218e-02,-0.0000e+00,7.0663e-02,-8.4947e-02],[7.8395e-03,-0.0000e+00,-1.9573e-01,-0.0000e+00,-1.1782e-03],[-3.8539e-02,-0.0000e+00,9.8646e-02,6.0333e-03,5.0500e-02]]]],
       grad_fn=<MulBackward0>)

最后,查看._forward_pre_hooks,当模块被剪枝时,它将为被剪枝相关的参数获取一个forward_pre_hook。在本例中,由于到目前为止我们只删除了名为weight的原始参数,因此只会出现一个hock。

print(module._forward_pre_hooks)# 只有一个hock,即weight

为了完整起见,我们现在删除bias,看看模块的参数、缓冲区、hook和属性是如何变化的。刚使用的是随机剪枝,这里我们用L1范数剪枝bias中最小的1个值。

prune.l1_unstructured(module, name="bias", amount=1)print(list(module.named_parameters()))

输出

[('weight_orig', Parameter containing:
tensor([[[[-0.1782, -0.0187,  0.1452, -0.0208, -0.0617],
          [ 0.0813,  0.0256,  0.1087, -0.1082, -0.1830],
          [-0.0367,  0.0012,  0.0620, -0.1082,  0.0187],
          [ 0.0512, -0.1047,  0.1500,  0.0423,  0.1635],
          [ 0.0557, -0.1944, -0.0240,  0.1285,  0.1127]]],

        [[[ 0.1026,  0.1493, -0.1530, -0.0785, -0.0283],
          [-0.0522, -0.1460,  0.1743,  0.0715, -0.1267],
          [ 0.0432,  0.1593, -0.1477, -0.0808,  0.1490],
          [ 0.0948,  0.1654,  0.1932,  0.1086,  0.0209],
          [ 0.1473, -0.1026, -0.1357,  0.1201,  0.1384]]],

        [[[-0.0339, -0.0469, -0.0912, -0.0095,  0.0836],
          [-0.0617,  0.1394,  0.1318,  0.0559,  0.1921],
          [ 0.1632, -0.1898,  0.0593,  0.1321,  0.1799],
          [ 0.0223, -0.0060,  0.0369,  0.1915,  0.0874],
          [-0.1907, -0.1033,  0.0613,  0.0018, -0.2000]]]], requires_grad=True)), ('bias_orig', Parameter containing:
tensor([ 0.1157,  0.1016, -0.0943], requires_grad=True))]

可以看到,weight替换为weight_orig,bias替换为bias_orig

print(list(module.named_buffers()))print(module.bias)print(module._forward_pre_hooks)

输出

[('weight_mask', tensor([[[[0.,1.,1.,0.,0.],[1.,0.,1.,1.,1.],[1.,0.,1.,1.,1.],[1.,1.,1.,0.,1.],[1.,0.,1.,1.,1.]]],[[[1.,0.,1.,1.,1.],[1.,1.,1.,1.,0.],[1.,1.,1.,0.,1.],[1.,1.,0.,1.,1.],[1.,1.,1.,1.,0.]]],[[[1.,1.,1.,0.,1.],[1.,0.,0.,0.,0.],[1.,0.,0.,1.,0.],[1.,1.,0.,1.,1.],[1.,1.,0.,1.,1.]]]])),('bias_mask', tensor([1.,0.,1.]))]
tensor([0.0827,-0.0000,0.1126], grad_fn=<MulBackward0>)
OrderedDict([(0,<torch.nn.utils.prune.RandomUnstructuredobject at0x000001E77DCF5B88>),(1,<torch.nn.utils.prune.L1Unstructuredobject at0x000001E7739631C8>)])

4.总结

本示例首先搭建了一个类LeNet网络模型,为了进行非结构化剪枝,我们选取了LeNet的conv1模块,该模块参数包含为3×5×5的weight卷积核参数和3×1的bias参数,通过示例,我们利用torch.nn.prune中的剪枝方法,实现了对weight参数进行30%随机非结构化剪枝,以及对bias的L1非结构化剪枝。

本文用到的核心函数方法:

  • module.named_parameters(),需转换为list对其可视化
  • module.buffers(),需转换为list对其可视化
  • module.weight,直接打印模块的weight参数
  • module.bias,直接打印模块的bias参数
  • prune.random_unstructured(),随机非结构化剪枝
  • prune.l1_unstructured(),L1非结构化剪枝

参考:

Pytorch模型中的parameter与buffer

Torch官方剪枝教程

  • 作者:小风_
  • 原文链接:https://blog.csdn.net/qq_33952811/article/details/124346514
    更新时间:2022-09-24 09:18:40