
I defined a model (MNIST digit recognition) using TensorFlow 2.15.0 and tensorflow.compat.v1. The model was **not** trained, and the graph was exported using the following code:

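```python
import tensorflow.compat.v1 as tf

# ... model definition (placeholders, variables, layers) ...

init = tf.global_variables_initializer()     # adds the "init" op run from C++
saver_def = tf.train.Saver().as_saver_def()  # adds the save/* nodes to the graph

with open('frozen_models/graph_v1.pb', 'wb') as f:
    f.write(tf.get_default_graph().as_graph_def().SerializeToString())
```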
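To see where the node names used later on the C++ side come from, here is a minimal sketch that builds a tiny stand-in graph and prints the fields of the SaverDef returned by as_saver_def(). The names x, w and logits are hypothetical and are not the real MNIST model's names. As far as I know, a default tf.train.Saver places its nodes under the save/ name scope, so the filename feed should be save/Const:0, the save target save/control_dependency:0 and the restore target save/restore_all.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Tiny stand-in for the real model; "x", "w" and "logits" are hypothetical names.
graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2], name="x")
    w = tf.Variable(tf.zeros([2, 1]), name="w")
    logits = tf.matmul(x, w, name="logits")
    init = tf.global_variables_initializer()
    # Creating a Saver adds the save/* subgraph to this graph; the returned
    # SaverDef only records the names of the relevant nodes.
    saver_def = tf.train.Saver().as_saver_def()

print("init op:       ", init.name)
print("filename feed: ", saver_def.filename_tensor_name)
print("save target:   ", saver_def.save_tensor_name)
print("restore target:", saver_def.restore_op_name)

# Serialize the GraphDef the same way as in the question.
graph_path = os.path.join(tempfile.mkdtemp(), "graph_v1.pb")
with open(graph_path, "wb") as f:
    f.write(graph.as_graph_def().SerializeToString())
```

Since only the GraphDef is written to disk (the SaverDef object itself is discarded), any later loader has to rely on these conventional node names rather than on a stored SaverDef.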

Later I loaded this graph into a C++ application and trained the model using the TensorFlow C API (version 2.18.0). I saved a checkpoint using the operations created by the Python call tf.train.Saver().as_saver_def(). This produces the following files:

  • checkpoint_1.data-00000-of-00001
  • checkpoint_1.index
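These two files form an ordinary TensorFlow checkpoint addressed by the prefix checkpoint_1, so they can be inspected from Python without the graph. The sketch below first writes such a checkpoint the way the C++ code does (feed save/Const:0, run save/control_dependency) from a hypothetical one-variable graph, then lists its contents with tf.train.list_variables and reads a value with tf.train.load_checkpoint.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Hypothetical one-variable graph standing in for the trained MNIST model.
graph = tf.Graph()
with graph.as_default():
    w = tf.Variable([[1.0], [2.0]], name="w")
    init = tf.global_variables_initializer()
    tf.train.Saver()  # adds save/Const, save/control_dependency, ...

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "checkpoint_1")  # a prefix, not a file name

with tf.Session(graph=graph) as sess:
    sess.run(init)
    # Same feed/target pair as the C++ TF_SessionRun call.
    sess.run("save/control_dependency", feed_dict={"save/Const:0": prefix})

print(sorted(os.listdir(ckpt_dir)))

# Read the checkpoint back by prefix; no graph is needed for this.
for name, shape in tf.train.list_variables(prefix):
    print(name, shape)
reader = tf.train.load_checkpoint(prefix)
print(reader.get_tensor("w"))
```

Running this against the real checkpoint_1 prefix is a quick way to check which variable names the restore op will look for.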

The code looked roughly like this:

```cpp
// initialization
save_op = TF_GraphOperationByName(graph, "save/control_dependency");
restore_op = TF_GraphOperationByName(graph, "save/restore_all");  // default SaverDef restore op
checkpoint_file.oper = TF_GraphOperationByName(graph, "save/Const");
checkpoint_file.index = 0;

// later: `tensor` is a scalar TF_STRING tensor holding the checkpoint prefix

TF_Output inputs[] = { checkpoint_file };
TF_Tensor* input_values[] = { tensor };

const TF_Operation* ops[] = {
    type == CheckpointType::Save ? save_op : restore_op
};

TF_SessionRun(
    session,
    NULL,            /* run options */
    /* Inputs */
    inputs,
    input_values,
    1,
    /* Outputs */
    NULL,
    NULL,
    0,
    /* Target operations */
    ops,
    1,
    NULL,            /* run metadata */
    status
);
```
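For comparison, the TF_SessionRun call above (one feed, no fetches, one target operation) corresponds to a plain Session.run in Python. This minimal sketch assumes the default Saver node names and a hypothetical one-variable graph: it saves, overwrites the variable, and restores it through the same two targets. save/restore_all is the default restore_op_name of a SaverDef and is presumably what restore_op refers to.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

FILENAME_TENSOR = "save/Const:0"         # checkpoint_file in the C++ code
SAVE_TARGET = "save/control_dependency"  # save_op
RESTORE_TARGET = "save/restore_all"      # restore_op (default SaverDef name)


def run_checkpoint_op(sess, target, prefix):
    """Python analogue of the TF_SessionRun call: one feed, no fetches, one target."""
    sess.run(target, feed_dict={FILENAME_TENSOR: prefix})


# Hypothetical one-variable graph standing in for the real model.
graph = tf.Graph()
with graph.as_default():
    w = tf.Variable([[1.0], [2.0]], name="w")
    init = tf.global_variables_initializer()
    tf.train.Saver()  # adds the save/* nodes
    clobber = w.assign(tf.zeros([2, 1]), read_value=False)  # overwrites w

prefix = os.path.join(tempfile.mkdtemp(), "checkpoint_1")
with tf.Session(graph=graph) as sess:
    sess.run(init)
    run_checkpoint_op(sess, SAVE_TARGET, prefix)     # writes checkpoint_1.*
    sess.run(clobber)                                # w is now all zeros
    run_checkpoint_op(sess, RESTORE_TARGET, prefix)  # w is [[1], [2]] again
    restored = sess.run(w)

print(restored)
```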

My question is: how can I load this graph into TensorFlow (preferably 2.x) from Python and restore the weights, so that the model can be exported in the SavedModel format? Is this even possible?

Disclaimer: I know it would be easier to just do it all in Python, but my goal is to learn how this toolchain and these formats work, not to create an ML model.

I tried a few approaches, but maybe I'm missing some necessary piece.

For example, I imported the graph using the following code, but I don't really know where to go from here:

```python
import tensorflow.compat.v1 as tf

with tf.io.gfile.GFile(graph_filename, "rb") as f:
    graph_def = tf.GraphDef()            # start from an empty GraphDef
    graph_def.ParseFromString(f.read())  # read the file only once

tf.import_graph_def(graph_def, name="")  # keep the original node names
```
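A minimal sketch of how the import can be continued, assuming the saver nodes kept their default names: import the GraphDef into a fresh graph with name="" (so save/Const and save/restore_all keep their names), then run the graph's own restore op with the checkpoint prefix fed into the filename tensor. The make_demo_files helper and the x/w/logits names are hypothetical stand-ins for graph_v1.pb and the checkpoint_1.* files produced by the C++ program.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf


def make_demo_files(out_dir):
    """Hypothetical stand-in for graph_v1.pb and the C++-trained checkpoint_1.*."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, [None, 2], name="x")
        w = tf.Variable(tf.zeros([2, 1]), name="w")
        tf.matmul(x, w, name="logits")
        init = tf.global_variables_initializer()
        tf.train.Saver()
        fake_training = w.assign([[1.0], [2.0]], read_value=False)
    graph_path = os.path.join(out_dir, "graph_v1.pb")
    with open(graph_path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())  # untrained graph
    prefix = os.path.join(out_dir, "checkpoint_1")
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        sess.run(fake_training)
        sess.run("save/control_dependency", {"save/Const:0": prefix})
    return graph_path, prefix


graph_filename, checkpoint_prefix = make_demo_files(tempfile.mkdtemp())

# Parse the serialized GraphDef (read once, into an empty GraphDef).
graph_def = tf.GraphDef()
with tf.io.gfile.GFile(graph_filename, "rb") as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="")  # keep original node names

with tf.Session(graph=graph) as sess:
    # restore_all assigns every saved variable, so "init" need not run first.
    sess.run("save/restore_all", feed_dict={"save/Const:0": checkpoint_prefix})
    out = sess.run("logits:0", feed_dict={"x:0": [[1.0, 1.0]]})

print(out)
```

The prefix is fed without the .index or .data-00000-of-00001 suffix, exactly like the string placed in the TF_STRING tensor on the C++ side.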

  • To correctly load your TF1 graph and checkpoint, you must use tensorflow.compat.v1 to run the graph's original restore operation within a session. For the complete implementation, please refer to this gist for the solution. Commented Oct 7 at 8:10
  • @Sagar it seems to work, thank you! Commented Oct 7 at 19:24
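The gist itself is not reproduced here. As a hedged sketch of one way to finish the job along the lines of that comment: after restoring, give a SavedModel builder a Saver that wraps the imported save/* nodes. This matters because tf.import_graph_def does not populate the variable collections, so a Saver built from those collections (which, as far as I know, is what tf.compat.v1.saved_model.simple_save creates internally) would find no variables to write. The make_demo_files helper, the tensor names x:0 and logits:0, and the export path are hypothetical; for the real model, look up the actual input and output node names in the GraphDef.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf


def make_demo_files(out_dir):
    """Hypothetical stand-in for graph_v1.pb and the C++-trained checkpoint_1.*."""
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, [None, 2], name="x")
        w = tf.Variable(tf.zeros([2, 1]), name="w")
        tf.matmul(x, w, name="logits")
        init = tf.global_variables_initializer()
        tf.train.Saver()
        fake_training = w.assign([[1.0], [2.0]], read_value=False)
    graph_path = os.path.join(out_dir, "graph_v1.pb")
    with open(graph_path, "wb") as f:
        f.write(graph.as_graph_def().SerializeToString())
    prefix = os.path.join(out_dir, "checkpoint_1")
    with tf.Session(graph=graph) as sess:
        sess.run(init)
        sess.run(fake_training)
        sess.run("save/control_dependency", {"save/Const:0": prefix})
    return graph_path, prefix


graph_filename, checkpoint_prefix = make_demo_files(tempfile.mkdtemp())
export_dir = os.path.join(tempfile.mkdtemp(), "saved_model")  # must not exist yet

graph_def = tf.GraphDef()
with tf.io.gfile.GFile(graph_filename, "rb") as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="")
    # Wrap the imported save/* nodes in a Saver (assumes default Saver names).
    saver = tf.train.Saver(saver_def=tf.train.SaverDef(
        filename_tensor_name="save/Const:0",
        save_tensor_name="save/control_dependency:0",
        restore_op_name="save/restore_all",
        version=tf.train.SaverDef.V2,
    ))
    signature = tf.saved_model.predict_signature_def(
        inputs={"x": graph.get_tensor_by_name("x:0")},
        outputs={"logits": graph.get_tensor_by_name("logits:0")},
    )
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, checkpoint_prefix)  # runs save/restore_all
        builder = tf.saved_model.Builder(export_dir)
        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature,
            },
            saver=saver,  # reuse the imported save/* nodes to write variables/
        )
        builder.save()

print(sorted(os.listdir(export_dir)))
```

Reloading the export with tf.compat.v1.saved_model.load in a fresh graph and session, then running logits:0, is a simple check that the restored weights actually ended up under variables/.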
