pytorch使用(五)使用pytorch进行微调(fine-tuning)

2022-09-24 14:37:16

pytorch使用:目录


pytorch使用(五)使用pytorch进行微调(fine-tuning)

在使用pytorch的时候,发现使用预训练的模型进行微调的时候有比较难的两步,一是如何加载需要的两部分模型

1. 定义网络并且加载网络参数
  • 首先定义自己模型并且加载预训练网络的模型和参数,定义自己模型的时候把想要用的层名字设置为和预训练模型一样的
  • 加载预训练模型中的参数到自己的模型
#load the pre-trained network
model_zero = C3D()
model_zero.load_state_dict(torch.load(paraPath))

model = ROI_C3D(classes=para['nClass'])#ROI_C3Dis my net
model_dict = model.state_dict()

model_zero = {k: vfor k, vin model_zero.state_dict().items()if kin model_dict}
model_dict.update(model_zero)
model.load_state_dict(model_dict)
2. 设置学习率

通常预训练层的学习率会低一些. 在下面这个例子中,在定义网络的时候,相比原来的模型,将最后一个全连接的名字改为了classifier

#set optimizationmethodignored_params =list(map(id, model.classifier.parameters())) #layerneedtobetrainedbase_params =filter(lambda p: id(p)notinignored_params,model.parameters())optimizer =optim.SGD([
    {'params': base_params},
    {'params': model.classifier.parameters(), 'lr': para['lr']*0.1}],0.001, momentum=0.9, weight_decay=1e-4)

这样预训练的模型学习率是0.0001,而最后一个全连接是0.001

  • 作者:贪泉觉爽
  • 原文链接:https://blog.csdn.net/GYGuo95/article/details/79945631
    更新时间:2022-09-24 14:37:16