加载tf模型 正确率很低_tensorflow修改模型输入

2022-12-26 07:58:39

CTR模型如果用estimator训练出来,直接导出为savedModel格式,上线时会有冗余计算,一般需要只计算模型的一个子图即可。

假设有下面这样一个计算图:

6b1baa68f86b6ba5a66f019d4ad5da99.png

其中:

a是一个placeholder

b、d是一个variable

c是一个tf.add节点

out是一个tf.multiply节点

一般情况下,我们保存了这个图,然后再加载,输入把a以feed_dict的形式传入参数就可以对图进行计算,现在的需求是,把其中任意一个节点替换掉,比如不计算a+b,直接给一个输入c;或者d不再是常量,而是运行是传入参数。这些需求都可以用tf的API实现。具体参考代码:

保存模型:

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

from tensorflow.python.framework.graph_util import convert_variables_to_constants

a = tf.placeholder(dtype=tf.float32, shape=(1, 2), name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')

c = tf.add(a, b, name='c')
d = tf.Variable(2, dtype=tf.float32, name='d')

out = tf.multiply(c, d, name='out')

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    graph = convert_variables_to_constants(sess, sess.graph_def, ['out'])
    tf.train.write_graph(graph, '.', 'graph_placeholder.pb', as_text=False)
    tf.train.write_graph(graph, '.', 'graph_placeholder_txt.pb', as_text=True)

加载模型:

import numpy as np
#import tensorflow as tf
from google.protobuf import text_format
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

with tf.Session() as sess:
    with open('./graph_placeholder.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        data = np.array([[3.],[2.]], np.float32)
        output = tf.import_graph_def(graph_def, input_map={'a:0': data}, return_elements=['out:0'])
        print(sess.run(output))

with tf.Session() as sess:
    with open('./graph_placeholder.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        data = np.array([[3.],[2.]], np.float32)
        output = tf.import_graph_def(graph_def, input_map={'c:0': 5.}, return_elements=['out:0'])
        print(sess.run(output))

with tf.Session() as sess:
    with open('./graph_placeholder_txt.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        text_format.Merge(f.read(), graph_def)
        data = np.array([[7.], [3.]], np.float32)
        output = tf.import_graph_def(graph_def, input_map={'a:0': data, 'd:0': 9.}, return_elements=['out:0'])
        print(sess.run(output))

参考资料:

https://tang.su/2017/01/export-TensorFlow-network/

  • 作者:weixin_39873741
  • 原文链接:https://blog.csdn.net/weixin_39873741/article/details/110884205
    更新时间:2022-12-26 07:58:39