Tensorflow迁移学习加载预训练模型并修改变量shape和value的方法

2022年12月27日09:59:54

迁移学习在深度学习中是经常被使用的方法,指的是在一个任务中预训练的模型被用于另一个任务的模型训练,以加快模型训练,减少资源消耗。
然而网络搜索相关的话题,基本上只涉及加载预训练模型的特定变量值的方法,即不涉及预训练模型某个变量与当前任务网络对应变量shape改变的处理。
在具体的语音合成多说话人模型迁移学习得到单说话人模型的任务中,就涉及到了迁移变量改变shape的情况,将解决方法如下列出。

一. 问题来源

       语音合成多说话人模型迁移学习得到单说话人模型的任务中,涉及了迁移变量改变shape的情况。

       一个不可避免的是,多说话人模型中由于存在speaker embedding变量,并往往与其他变量拼接作为模块的输入,在单人模型中,speaker embedding不存在,且对应模块内部部分变量的维度也必然改变。

       另外,在实际处理时发现,由于之前多人建模与单人建模的文本字符输入表symbol list存在差异,导致预训练模型中的char embeddings变量在单人模型中不能直接使用。

       种种情况,涉及的都是变量shape发生了变化,但是预训练模型中该变量的一部分值仍然可用。我们当然可以直接舍弃该变量已有值,重新训练,但是很多情况下没有必要。通过实践,来讲讲在Tensorflow框架下该问题的解决方法。

二. 相关接口

       网络搜索“迁移学习”或者“加载预训练模型”等,能够找到一些方法。比如,tensorflow新模型怎么加载老数据?中,涉及到如下接口:

#新网络结构的搭建
...... 
 #利用预训练模型给新网络变量赋值时,忽略变量v3
restore_variable_list = tf.contrib.framework.get_variables_to_restore(exclude=["v3",]) 
saver = tf.train.Saver(restore_variable_list)
saver.restore(sess, old_checkpoint_path)

       需要注意的是,以上代码调用之前,需要先搭建新的网络结构,而get_variables_to_restore就是列出了新网络中的所有变量,变量是tf.Variable类型,这时候还没有加载预训练模型中的值。

       而在干货!如何修改在TensorFlow框架下训练保存的模型参数名称中,涉及如下接口:

#得到预训练模型中所有的参数(名字,形状)元组
for var_name, var_shape in tf.train.list_variables(old_checkpoint_path): 
		#得到上述参数的值
        var = tf.train.load_variable(old_checkpoint_path, var_name) 

       与前一例子相比,本例代码不需要先搭建新的网络结构,直接从预训练模型中遍历所有变量,取得其名称、形状和值。注意var也就是值,是numpy.array格式的。本例链接中提到,可以通过tf.Variable(var)修改变量值后重新构建变量,但是在实际应用中,在我们已经构建好了新的网络结构之后,再通过tf.Variable构建的是新的变量了,即使手动赋予变量相同的名字。

三. 解决方法

       首先确立问题的前提,我们有了一个预训练模型,并且构建了新的网络结构,但是该网络某个变量与预训练模型中的变量shape不一致。

       我们假定这两个变量同名,如char_embeddings,并且预训练模型中该变量shape为[30, 128],表示支持30个字符[a-z, @, #, %, &]输入,每个字符用128维表示。新的网络中我们只支持29个字符[a-z, @, #, %],这样新网络中,变量char_embeddings的shape为[29, 128],此时,加载预训练模型,同名变量赋值显然会报错。

       在实际操作时,我们首先通过(二)中的两类接口,获取两个网络结构不一致变量的完整名称及其shape。注意此时可以观察到,每个变量对应的两个变量char_embeddings/Adam和char_embeddings/Adam_1的shape也有相同的差异,这些变量也需要对应修改。

       然后通过上述list_variable和load_variable接口获取预训练模型特定接口的值,并进行修改,本例中使用var=var[:-1, :]就可以从原[30,128]中获取前29个字符的预训练值,临时保存。然后在新的网络结构搭建好之后,调用get_variables_to_restore获取当前网络的所有变量,像上例一中一样,忽略待修改变量以便在后续操作中restore其他变量。而待修改变量,通过tf.convert_to_tensor及tf.assign操作,将临时保存的修改值赋予当前网络对应变量,在restore操作后通过session.run进行实际赋值。详细代码如下:

#保存待修改变量名与修改后的变量值
map_name_newval = {}
for var_name, var_shape in tf.train.list_variables(old_checkpoint_path): 
		if var_name.find('char_embeddings'):
        		var = tf.train.load_variable(old_checkpoint_path, var_name) 
        		map_name_newval[val_name] = var[:-1, :] #修改值

#当前网络结构区分直接恢复预训练值的变量与修改变量
var_list = tf.contrib.framework.get_variables_to_restore() 
restore_var_list = []
map_vname_op = {}
for v in var_list:
		v_name = v.name.split(':')[0] #两种方式取得的模型名称有一点差异
		if v_name in map_name_newval:
				new_val = tf.convert_to_tensor(map_name_newval[v_name])
				v_assign = tf.assign(v, new_val) #给当前网络中的特定变量赋新值操作
				map_vname_op[v.name] = v_assign
		else:
				restore_val_list.append(v)

#实际变量赋值操作
with tf.Session() as sess:
		saver = tf.train.Saver(val_list=restore_variable_list)
		saver.restore(sess, old_checkpoint_path)
		for vname, op in map_vname_op.items():
				sess.run(op)

       本例中修改值比较简单,但是无论如何修改,整体的思路和流程都是一样的。

  • 作者:lujian1989
  • 原文链接:https://blog.csdn.net/lujian1989/article/details/104334685
    更新时间:2022年12月27日09:59:54 ,共 2861 字。