如何在python中从.pb文件中恢复Tensorflow模型?

26 人关注

我有一个tensorflow .pb文件,我想把它加载到python DNN中,恢复图形并获得预测结果。我这样做是为了测试创建的.pb文件是否能做出类似于正常Saver.save()模型的预测。

我的基本问题是,当我在安卓上使用上述.pb文件进行预测时,得到的预测值非常不同。

我的.pb文件创建代码。

frozen_graph = tf.graph_util.convert_variables_to_constants(
        session,
        session.graph_def,
        ['outputLayer/Softmax']
with open('frozen_model.pb', 'wb') as f:
  f.write(frozen_graph.SerializeToString())

因此,我有两个主要关切。

  • How can I load the above mentioned .pb file to python Tensorflow model ?
  • Why am I getting completely different values of prediction in python and android ?
  • android
    python
    tensorflow
    vizsatiz
    vizsatiz
    发布于 2018-06-01
    2 个回答
    Pranjal Sahu
    Pranjal Sahu
    发布于 2020-05-07
    已采纳
    0 人赞同

    下面的代码将读取模型并打印出图中节点的名称。

    import tensorflow as tf
    from tensorflow.python.platform import gfile
    GRAPH_PB_PATH = './frozen_model.pb'
    with tf.Session() as sess:
       print("load graph")
       with gfile.FastGFile(GRAPH_PB_PATH,'rb') as f:
           graph_def = tf.GraphDef()
       graph_def.ParseFromString(f.read())
       sess.graph.as_default()
       tf.import_graph_def(graph_def, name='')
       graph_nodes=[n for n in graph_def.node]
       names = []
       for t in graph_nodes:
          names.append(t.name)
       print(names)
    

    你正在正确地冻结图形,这就是为什么你会得到不同的结果,基本上权重没有被存储在你的模型中。你可以使用freeze_graph.py (link),以获得一个正确存储的图形。

    What is sess.graph.as_default() doing?
    I get on 【替换代码0 DecodeError:解读信息错误
    Use tf.gfile.GFile instead of gfile.FastGFile in 2019
    当我试图用你的程序和facenet模型来显示节点名称时,它给了我这个错误 ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/Switch was passed float from phase_train:0 incompatible with expected bool. 。你知道为什么会出现这种情况吗?谢谢
    @Sneha 你可能传错了数据类型。它期望的是bool,但它得到的是float。
    caylus
    caylus
    发布于 2020-05-07
    0 人赞同

    这里是tensorflow 2的更新代码。

    import tensorflow as tf
    GRAPH_PB_PATH = './frozen_model.pb'
    with tf.compat.v1.Session() as sess:
       print("load graph")
       with tf.io.gfile.GFile(GRAPH_PB_PATH,'rb') as f:
           graph_def = tf.compat.v1.GraphDef()
       graph_def.ParseFromString(f.read())
       sess.graph.as_default()
       tf.import_graph_def(graph_def, name='')
       graph_nodes=[n for n in graph_def.node]