Tensorflow Convert pb file to TFLITE using python

You can convert to tflite directly in Python. You have to freeze the graph and then call toco_convert. As with the command-line tool, the input and output names and shapes must be known before calling the API.

An example code snippet

Copied from the documentation, where a "frozen" (no variables) graph is defined as part of your code:

import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
  tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
  open("test.tflite", "wb").write(tflite_model)

In the example above, there is no freeze-graph step because there are no variables. If your graph has variables and you run toco without first freezing it (i.e. converting those variables to constants), toco will complain!

If you have frozen graphdef and know the inputs and outputs

Then you don't need the session. You can call the toco API directly:

import tensorflow as tf

path_to_frozen_graphdef_pb = '...'
input_tensors = [...]
output_tensors = [...]
frozen_graph_def = tf.GraphDef()
with open(path_to_frozen_graphdef_pb, 'rb') as f:
  frozen_graph_def.ParseFromString(f.read())
tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)
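The snippet above leaves the converted model in memory. Here is a minimal sketch of saving it with a quick sanity check, assuming `tflite_model` is the `bytes` object returned by the converter and that the file carries the "TFL3" FlatBuffer identifier of the TFLite schema. `looks_like_tflite` and `save_tflite` are hypothetical helper names, not part of TensorFlow:

```python
TFLITE_IDENTIFIER = b"TFL3"  # assumed FlatBuffer file identifier, stored at bytes 4..8


def looks_like_tflite(data):
    """Return True if `data` carries the TFLite FlatBuffer identifier."""
    return len(data) >= 8 and data[4:8] == TFLITE_IDENTIFIER


def save_tflite(model_bytes, path):
    """Write the converted model to `path` and return the number of bytes written."""
    if not looks_like_tflite(model_bytes):
        raise ValueError("data does not look like a TFLite flatbuffer")
    with open(path, "wb") as f:
        return f.write(model_bytes)
```

With the variables from the snippet above this would be called as `save_tflite(tflite_model, "test.tflite")`. The identifier check only catches gross mistakes, such as writing the GraphDef instead of the converted model.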

If you have non-frozen graphdef and know the inputs and outputs

Then you have to load the graph into a session and freeze it before calling toco:

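import tensorflow as tf

path_to_graphdef_pb = '...'
graph_def = tf.GraphDef()
with open(path_to_graphdef_pb, 'rb') as f:
  graph_def.ParseFromString(f.read())
output_node_names = ["..."]
input_tensors = [...]
output_tensors = [...]

# tf.Session expects a tf.Graph, not a GraphDef, so import the GraphDef into a Graph first.
graph = tf.Graph()
with graph.as_default():
  tf.import_graph_def(graph_def, name="")

with tf.Session(graph=graph) as sess:
  # The variables must hold values (initialized or restored) before they can be frozen.
  frozen_graph_def = tf.graph_util.convert_variables_to_constants(
      sess, sess.graph_def, output_node_names)
# Note here we are passing frozen_graph_def obtained in the previous step to toco.
tflite_model = tf.contrib.lite.toco_convert(frozen_graph_def, input_tensors, output_tensors)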

If you don't know inputs / outputs of the graph

This can happen if you did not define the graph yourself, e.g. you downloaded it from somewhere or used a high-level API like tf.estimator that hides the graph from you. In this case, you need to load the graph and poke around to figure out the inputs and outputs before calling toco. See my answer to this SO question.
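One way to "poke around" is a simple heuristic: treat Placeholder nodes as candidate inputs, and nodes whose output nobody consumes as candidate outputs. The sketch below is only an illustration under those assumptions. `Node`, `producer_name` and `find_graph_io` are hypothetical names, and `Node` is a stand-in for TensorFlow's NodeDef so the example runs without TensorFlow installed:

```python
from collections import namedtuple

# Stand-in for a NodeDef; nodes from graph_def.node expose the same three fields.
Node = namedtuple("Node", ["name", "op", "input"])


def producer_name(input_name):
    """Strip the control-dependency marker ('^') and the output index (':1')."""
    return input_name.lstrip("^").split(":")[0]


def find_graph_io(nodes):
    """Guess inputs (Placeholder ops) and outputs (nodes that nobody consumes)."""
    nodes = list(nodes)
    consumed = {producer_name(i) for n in nodes for i in n.input}
    inputs = [n.name for n in nodes if n.op == "Placeholder"]
    outputs = [n.name for n in nodes
               if n.name not in consumed and n.op not in ("Placeholder", "Const")]
    return inputs, outputs


# Toy graph mirroring the img -> add -> out example from the first snippet.
toy_graph = [
    Node("img", "Placeholder", []),
    Node("bias", "Const", []),
    Node("add", "Add", ["img", "bias"]),
    Node("out", "Identity", ["add:0"]),
]
print(find_graph_io(toy_graph))  # (['img'], ['out'])
```

With a real model you would pass `graph_def.node` instead of `toy_graph`. Treat the result as a list of candidates: training graphs often contain extra unconsumed nodes (savers, init ops) that are not real outputs.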


Following this TF example, you can pass the "--saved_model_dir" parameter when running the retrain.py script to export saved_model.pb and the variables folder to a directory (which must not exist yet):

python retrain.py ...... --saved_model_dir /home/..../export
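Once the export has finished, a quick check that the directory has the expected layout can save a confusing converter error later. This is a minimal sketch that assumes only what is described above, namely a saved_model.pb file and a variables folder inside the export directory. `check_saved_model_dir` is a hypothetical helper, not a TensorFlow function:

```python
import os


def check_saved_model_dir(export_dir):
    """Return a list of problems found in a SavedModel export directory."""
    problems = []
    if not os.path.isfile(os.path.join(export_dir, "saved_model.pb")):
        problems.append("missing saved_model.pb")
    if not os.path.isdir(os.path.join(export_dir, "variables")):
        problems.append("missing variables/ folder")
    return problems
```

An empty list means both pieces are present. It does not prove the model itself is valid.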

To convert your model to tflite, use the line below:

convert_saved_model.convert(saved_model_dir='/home/.../export',output_arrays="final_result",output_tflite='/home/.../export/graph.tflite')

Note: you need to import convert_saved_model:

from tensorflow.contrib.lite.python import convert_saved_model

Remember you can convert to tflite in 2 ways:

[image: the two ways to convert to tflite]

But the easiest way is to export saved_model.pb together with its variables, which lets you avoid build tools like Bazel.


This is what worked for me: (SSD_InceptionV2 model)

  1. Finish the training. I used model_main.py from the object_detection folder (TF v1.11).
  2. Export the graph for TFLite:
python /tensorflow/models/research/object_detection/export_tflite_ssd_graph.py \
  --pipeline_config_path annotations/ssd_inception_v2_coco.config \
  --trained_checkpoint_prefix trained-inference-graphs/inference_graph_v7.pb/model.ckpt \
  --output_directory trained-inference-graphs/inference_graph_v7.pb/tflite \
  --max_detections 3
  3. This generates a .pb file (tflite_graph.pb), from which you can generate the tflite file like this:
tflite_convert \
  --output_file=test.tflite \
  --graph_def_file=tflite_graph.pb \
  --input_arrays=normalized_input_image_tensor \
  --output_arrays='TFLite_Detection_PostProcess','TFLite_Detection_PostProcess:1','TFLite_Detection_PostProcess:2','TFLite_Detection_PostProcess:3' \
  --input_shape=1,300,300,3 \
  --allow_custom_ops
As for the input/output names, I am not 100% sure how best to get them, but this code has helped me before:

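import tensorflow as tf

frozen = '/tensorflow/mobilenets/mobilenet_v1_1.0_224.pb'
gf = tf.GraphDef()
with open(frozen, 'rb') as f:
  gf.ParseFromString(f.read())
# List candidate inputs (Placeholder) and outputs (Softmax) as "name=>op".
print([n.name + '=>' + n.op for n in gf.node if n.op in ('Softmax', 'Placeholder')])
# Same listing, but showing Mul nodes instead of Placeholder nodes.
print([n.name + '=>' + n.op for n in gf.node if n.op in ('Softmax', 'Mul')])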