在pytorch中进行预训练模型的加载和模型的fine-tune操作

2022-10-13 09:59:32

联系方式:
e-mail: FesianXu@163.com
QQ: 973926198
github: https://github.com/FesianXu

文章目录

  • 基模型参数加载
       从持久化模型开始
       加载模型
       部分加载模

  • 模型Fine-Tune

  • 给每一层或者每个模型设置不同的学习率

  • Pytorch内置的模型

  • Reference

在使用pytorch的时候,经常有需要使用一些通用的模型模块作为子模块,比如著名的resnet,densenet,alexnet,inception等等,在使用这些模型的时候,通常希望可以加载该模型在别的数据集(如ImageNet)上进行训练后的权值参数,以便加速整个模型训练过程[1]。在此为了简便,称之为这些预训练模型为基模型,加载基模型参数这个过程按照需求大致可以分为两类:

  1. 整个模型的预训练参数加载
  2. 部分模型的预训练参数加载
    在**对模型的参数进行fine-tune[**2-3]的时候,按照需求也可以大致分为两类:
  3. 固定整个基模型的参数,调节其他模型的参数
  4. 固定部分基模型的参数,调节其他模型的参数

1 基模型参数加载

1.1 从持久化模型开始

在pytorch中,保存一个模型的参数特别容易,用torch.save()即可,例如:

model= CNNNet(params)# 作者自定义的模型,没提供源码
opt= torch.optim.Adam(model.parameters(), lr=1e-4)
model.train()# here we train the models, skip these codes
saved_dict={'model': model.state_dict(),'opt': opt.state_dict()}
torch.save(saved_dict,'./model.pth.tar')

我们发现,torch.save()保存的是一个字典,其中的keys可以自定义。这里有一点要注意的是,如果你用的优化器是例如Adam优化器[5-6]这类内部有参数需要持久化的,最好也将其保存下来

1.2 加载模型

如果是需要加载整个模型,直接用torch.load()和model.load_state_dict()即可,如:

model= CNNNet(params)
opt= torch.optim.Adam(model.parameters(), lr=1e-4)# yes you also need to define the model and optimizer
checkpoint= torch.load('./model.pth.tar')# here, checkpoint is a dict with the keys you defined before
model.load_state_dict(checkpoint['model'])
opt.load_state_dict(checkpoint['opt'])

这个过程中torch.load()只是负责读取模型参数,而用model.load_state_dict()进行加载,这个加载是按照名字进行索引的,如果名字对不上或者是参数的形状,类型对不上,就会报错。我们可以打印出其名字进行观察,如:

for namein checkpoint['model'].keys():print(name)

输出如:

# stgcn是模型的名字,具体到不同的模型名字可能不同# weight_model可能是个调用了子模块A的另一个模块B的属性# conv_models是子模块A中调用了卷积与BN操作的属性
stgcn.weight_model.conv_models.0.conv_layer.weight
stgcn.weight_model.conv_models.0.conv_layer.bias
stgcn.weight_model.conv_models.1.conv_layer.weight
stgcn.weight_model.conv_models.1.conv_layer.bias
stgcn.weight_model.conv_models.1.batch_norm.weight
stgcn.weight_model.conv_models.1.batch_norm.bias

如果定义的模型(网络结构)和持久化的模型的参数名,形状,类型(网络参数)都能完全符合,就能正确加载。同时还注意到,变量的名字是由pytorch自行命名的,其命名根据就是你的变量名字,比如:

import torch.nnas nnclassmodel_A(nn.Module):def__init__(self):super().__init__()
        self.fc= nn.Linear(10,10)classmodel_B(nn.Module):def__init__(self):super().__init__()
        self.sub_model= model_A()
        self.fc= nn.Linear(10,1)
model= model_B()

那么如果你打印prin(model_B),你就会发现子模块的名字为

sub_model.fc.weight
sub_model.fc.bias
fc.weight
fc.bias

我们观察到名字是以你的变量标识符命名的,这点和TensorFlow的命名机制完全不同,请注意。同时我们还观察到,只要是权值weight其命名后缀都是weight,同样偏置bias的后缀是bias,因此根据此,可以单独对权值进行L2正则[7],具体过程见[8]。

1.3 部分加载模型

根据上面的分析,我们便发现只要过滤掉不需要加载的模型的名字,即可实现部分模型加载了,例子如:

model= CNNNet()
checkpoint= torch.load('./model.pth.tar')for name, paramsin model.stgcn.st_gcn_networks.named_parameters():
	 params_name='stgcn.st_gcn_networks.'+nameif params_namein model.state_dict():
    	model.state_dict()[params_name].copy_(checkpoint['model'][params_name])

我们发现,通过这个代码,我们可以仅对model.stgcn.st_gcn_networks这个子模块的参数进行加载,而其他的参数保持初始化情况不变。

2 模型Fine-Tune

在模型的Fine-Tune(微调)或者联合调试过程中,我们经常需要固定某个模型的参数,而去调整其他模型的参数,主要方法有:

  • 通过切断某个模块的梯度流,但是这个会导致该模型前面的所有模型也没有梯度。
  • 通过设置某个模块的所有变量的requires_grad=False
  • 在优化器内设置需要进行梯度更新的变量。
    笔者在实践过程中最常用的是第三种方法,暂时只介绍第三种方法。代码很简单,例子如下:
trainable_vars=list(model.stgcn.weight_model.parameters())+ \list(model.stgcn.fcn.parameters())+ \list(model.stgcn.data_bn.parameters())+ \list(model.stgcn.dim_map.parameters())+ \list(model.aux_cls.parameters())                
opt= torch.optim.SGD(trainable_vars, lr=1e-4, momentum=0.9)

简单粗暴,但是其实很好用,当需要训练的变量很多,而需要固定的变量很少的时候,可以用对整个模型参数求补的方式求得,这里不多介绍了。
当需要对整个模型进行微调时,只需要:

opt= torch.optim.SGD(model.parameters(), lr=1e-6, momentum=0.9)

给每一层或者每个模型设置不同的学习率

在模型训练过程中,有些模块,比如对抗生成网络GAN[9]的生成器和判别器经常需要设置不同的学习率,以求得更好的效果或者不同模型之间的平衡。详细内容见我以前文章[8]中所述,这里不再累述。

Pytorch内置的模型

pytorch内置有一些经常使用的模型和其在大规模数据集上的预训练参数,只需要安装torchvision便可轻松使用,模型有:

  • resnet: resnet18,resnet34,resnet50,resnet101,resnet152
  • vgg: vgg11,vgg13, vgg16, vgg19
  • alexnet
  • densenet
  • inception
  • squeezenet

具体模型定义见:Github click me
使用方法很简单,如:

import torchvision.modelsas models
resnet18= models.resnet18(pretrained=False)

在这里如果指定pretrained=True可以联网加载预训练模型,但是由于因为你懂得的原因,所以需要你懂得的辅助工具,建议读者自行去下载模型的文件后手动加载。模型文件的地址可以在模型定义文件中找到,如[https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py]中的resnet模型:

model_urls={'resnet18':'https://download.pytorch.org/models/resnet18-5c106cde.pth','resnet34':'https://download.pytorch.org/models/resnet34-333f7ec4.pth','resnet50':'https://download.pytorch.org/models/resnet50-19c8e357.pth','resnet101':'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth','resnet152':'https://download.pytorch.org/models/resnet152-b121ed2d.pth',}

Reference

[1]. Kaiming He, Ross Girshick, Piotr Dollár. Rethinking ImageNet Pre-training[J]. arXiv preprint, https://arxiv.org/abs/1811.08883
[2]. 迁移学习与fine-tuning有什么区别?
[3]. Fine tuning
[4]. pytorch
[5]. Adam 算法
[6]. Kingma D P, Ba J. Adam: A method for stochastic optimization[J]. arXiv preprint arXiv:1412.6980, 2014.
[7]. 曲线拟合问题与L2正则
[8]. pytorch中的L2和L1正则化,自定义优化器设置等操作
[9]. Goodfellow I, Pouget-Abadie J, Mirza M, et al. Generative adversarial nets[C]//Advances in neural information processing systems. 2014: 2672-2680.

  • 作者:还能坚持
  • 原文链接:https://blog.csdn.net/qq_35091353/article/details/108237184
    更新时间:2022-10-13 09:59:32