
I have a TensorFlow .pb file that I would like to load into a Python DNN, restore the graph, and get the predictions. I am doing this to test whether the .pb file I created makes predictions similar to those of the normal Saver.save() model.

My basic problem is that I am getting very different prediction values when I make predictions on Android using the .pb file mentioned above.

My .pb file creation code:

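import tensorflow as tf  # TensorFlow 1.x API

# Replace every variable with a constant holding its current value, keeping only
# the nodes needed to compute the 'outputLayer/Softmax' output.
frozen_graph = tf.graph_util.convert_variables_to_constants(
    session,
    session.graph_def,
    ['outputLayer/Softmax']
)
with open('frozen_model.pb', 'wb') as f:
    f.write(frozen_graph.SerializeToString())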

So I have two major concerns:

  1. How can I load the above .pb file into a Python TensorFlow model?
  2. Why am I getting completely different prediction values in Python and on Android?
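Before looking for a cause, it may help to quantify "completely different". The sketch below assumes you can dump the softmax outputs from both sides (the Saver-restored Python model and the Android app) as lists of probabilities for the same input. `compare_predictions` and its `atol` tolerance are hypothetical helpers for illustration, not part of TensorFlow.

```python
def compare_predictions(python_probs, android_probs, atol=1e-4):
    """Compare two softmax vectors produced for the same input.

    Returns (max_abs_diff, same_top_class, close_enough).
    """
    if len(python_probs) != len(android_probs):
        raise ValueError("prediction vectors have different lengths")
    # Largest element-wise difference between the two probability vectors
    max_abs_diff = max(abs(p - a) for p, a in zip(python_probs, android_probs))
    # argmax without numpy: index of the largest probability on each side
    top_py = max(range(len(python_probs)), key=python_probs.__getitem__)
    top_an = max(range(len(android_probs)), key=android_probs.__getitem__)
    return max_abs_diff, top_py == top_an, max_abs_diff <= atol
```

A large `max_abs_diff` together with a different top class points at a real difference between the two models rather than floating-point noise.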

2 Answers


The following code will read the model and print out the names of the nodes in the graph.

import tensorflow as tf  # TensorFlow 1.x API
from tensorflow.python.platform import gfile

GRAPH_PB_PATH = './frozen_model.pb'
with tf.Session() as sess:
    print("load graph")
    # Parse the serialized GraphDef while the file is still open
    with gfile.FastGFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import the nodes into the session's graph (name='' keeps the original node names)
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')
    # Collect and print the name of every node in the graph
    names = [n.name for n in graph_def.node]
    print(names)
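The code above only lists node names. To actually get predictions, you can look up the input and output tensors by name and feed the input. This is a minimal TensorFlow 1.x sketch: the output name `outputLayer/Softmax` comes from the question, while any input name you pass (e.g. `inputLayer`) is a hypothetical placeholder that you should replace with the name shown in the printed node list. The helper names are my own.

```python
def tensor_name(op_name, output_index=0):
    """Node names from graph_def.node become tensor names by appending ':<index>'."""
    return "{}:{}".format(op_name, output_index)


def predict_from_frozen_graph(pb_path, input_op, output_op, batch):
    """Load a frozen GraphDef and run one forward pass (TensorFlow 1.x API).

    TensorFlow is imported inside the function so that tensor_name() can be
    used without TensorFlow installed.
    """
    import tensorflow as tf

    with tf.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import into a fresh graph so nothing else in the default graph interferes
    graph = tf.Graph()
    with graph.as_default():
        tf.import_graph_def(graph_def, name="")

    x = graph.get_tensor_by_name(tensor_name(input_op))
    y = graph.get_tensor_by_name(tensor_name(output_op))
    with tf.Session(graph=graph) as sess:
        return sess.run(y, feed_dict={x: batch})
```

For example, `predict_from_frozen_graph('./frozen_model.pb', 'inputLayer', 'outputLayer/Softmax', batch)`, where `inputLayer` is hypothetical and `batch` must match the placeholder's shape and dtype.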

You are not freezing the graph properly; that is why you are getting different results: the weights are not being stored in your model. You can use the freeze_graph.py script to get a correctly frozen graph.
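If you suspect the weights are missing, one check is to inspect the op types in the frozen GraphDef: after a successful freeze, variable ops should have been replaced by `Const` nodes. A minimal sketch, assuming `graph_def` was loaded as in the code above; the helper names and the exact set of op types flagged are my own assumptions, not a TensorFlow API.

```python
from collections import Counter
from types import SimpleNamespace

# Op types that hold or read variable state; in a correctly frozen graph
# none of these should remain (assumed list, covering ref and resource variables).
VARIABLE_OP_TYPES = {"Variable", "VariableV2", "VarHandleOp", "ReadVariableOp"}


def leftover_variable_nodes(graph_def):
    """Return names of nodes in graph_def that still look like variables."""
    return [n.name for n in graph_def.node if n.op in VARIABLE_OP_TYPES]


def op_type_counts(graph_def):
    """Count nodes per op type, e.g. to see how many Const nodes hold the weights."""
    return Counter(n.op for n in graph_def.node)


if __name__ == "__main__":
    # Stand-in objects shaped like GraphDef/NodeDef, so the demo runs without TensorFlow.
    fake = SimpleNamespace(node=[
        SimpleNamespace(name="dense/kernel", op="Const"),
        SimpleNamespace(name="dense/bias", op="VariableV2"),
        SimpleNamespace(name="outputLayer/Softmax", op="Softmax"),
    ])
    print(leftover_variable_nodes(fake))  # ['dense/bias']
    print(op_type_counts(fake))
```

With a real model, pass the parsed `graph_def` instead of the stand-in; an empty list from `leftover_variable_nodes` suggests the variables were converted.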


6 Comments

What is sess.graph.as_default() doing?
I get DecodeError: Error parsing message on graph_def.ParseFromString(f.read()).
Use tf.gfile.GFile instead of gfile.FastGFile in 2019
When I try to display node names with your program and the FaceNet model, it gives me this error: ValueError: Input 0 of node InceptionResnetV1/Conv2d_1a_3x3/BatchNorm/cond/Switch was passed float from phase_train:0 incompatible with expected bool. Do you have any idea why it is happening? Thanks
@Sneha You are probably passing wrong datatype. It expects bool but it is getting float.

Here is the updated code for TensorFlow 2.

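import tensorflow as tf  # TensorFlow 2.x, using the tf.compat.v1 API

GRAPH_PB_PATH = './frozen_model.pb'
with tf.compat.v1.Session() as sess:
    print("load graph")
    # Parse the serialized GraphDef while the file is still open
    with tf.io.gfile.GFile(GRAPH_PB_PATH, 'rb') as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())
    # Import into the session's graph; inside as_default() ops are built in
    # graph mode rather than eagerly
    with sess.graph.as_default():
        tf.import_graph_def(graph_def, name='')
    # Collect and print the name of every node in the graph
    names = [n.name for n in graph_def.node]
    print(names)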
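In TensorFlow 2 you can also avoid managing a Session by wrapping the imported graph in a concrete function. This is a sketch of the `tf.compat.v1.wrap_function` pattern that TensorFlow's TF1-to-TF2 migration guide uses for frozen graphs; the function name is my own, and any input tensor name you pass (e.g. `x:0`) is a hypothetical placeholder, while `outputLayer/Softmax:0` is the output from the question.

```python
def load_frozen_graph_tf2(pb_path, inputs, outputs):
    """Return a callable ConcreteFunction mapping `inputs` tensor names to `outputs`.

    TensorFlow is imported lazily so this module can be imported without it.
    """
    import tensorflow as tf

    with tf.io.gfile.GFile(pb_path, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    def _import():
        # Traced once by wrap_function, which imports the nodes into a new graph
        tf.compat.v1.import_graph_def(graph_def, name="")

    wrapped = tf.compat.v1.wrap_function(_import, [])
    graph = wrapped.graph
    # prune() keeps only the subgraph between the requested inputs and outputs
    return wrapped.prune(
        tf.nest.map_structure(graph.as_graph_element, inputs),
        tf.nest.map_structure(graph.as_graph_element, outputs),
    )
```

For example, `predict = load_frozen_graph_tf2('./frozen_model.pb', 'x:0', 'outputLayer/Softmax:0')` followed by `probs = predict(tf.constant(batch))`, where the input tensor must match the placeholder's dtype and shape.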

1 Comment

I got DecodeError: Error parsing message with type 'tensorflow.GraphDef'.
