Connect input and output tensors of two different graphs in TensorFlow

The accepted answer does connect the two graphs, but it does not restore the collections, such as the global and trainable variables. After an exhaustive search, I arrived at a better solution:

import tensorflow as tf
from tensorflow.python.framework import meta_graph

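# Two independent graphs, each with its own 'input' placeholder and 'output' tensor.
# (The placeholder is bound to `inp` to avoid shadowing Python's built-in input().)
with tf.Graph().as_default() as graph1:
    inp = tf.placeholder(tf.float32, (None, 20), name='input')
    output = tf.identity(inp, name='output')

with tf.Graph().as_default() as graph2:
    inp = tf.placeholder(tf.float32, (None, 20), name='input')
    output = tf.identity(inp, name='output')

# The combined graph is the global default graph, with a fresh input placeholder.
graph = tf.get_default_graph()
x = tf.placeholder(tf.float32, (None, 20), name='input')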

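We use tf.train.export_meta_graph, which also exports the CollectionDef entries, and meta_graph.import_scoped_meta_graph to import the result. This is where the connection happens, specifically through the input_map parameter, which substitutes an existing tensor for the named tensor of the imported graph.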

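# Import graph1 under the 'graph1/' scope, feeding x into its 'input' tensor.
meta_graph1 = tf.train.export_meta_graph(graph=graph1)
meta_graph.import_scoped_meta_graph(meta_graph1, input_map={'input': x}, import_scope='graph1')
out1 = graph.get_tensor_by_name('graph1/output:0')

# Import graph2 under the 'graph2/' scope, feeding graph1's output into its 'input' tensor.
meta_graph2 = tf.train.export_meta_graph(graph=graph2)
meta_graph.import_scoped_meta_graph(meta_graph2, input_map={'input': out1}, import_scope='graph2')
out2 = graph.get_tensor_by_name('graph2/output:0')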

Now the graph is connected, and the global variables are re-mapped into the import scopes as well.

print(tf.global_variables())
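Because graph1 and graph2 above define no variables, that print outputs an empty list. The sketch below gives a hypothetical graph1 a variable `w` so the re-mapping becomes visible, and contrasts it with a plain tf.import_graph_def of the same graph, which copies the nodes but registers no variables. It assumes tensorflow.compat.v1 with tf.disable_v2_behavior() so that it also runs under TF 2.x; the names `w` and `plain` are illustrative only.

```python
import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

tf.disable_v2_behavior()  # graph mode, as in the TF 1.x code above

# Hypothetical graph1 that owns one trainable variable `w`.
with tf.Graph().as_default() as graph1:
    inp = tf.placeholder(tf.float32, (None, 20), name='input')
    w = tf.Variable(tf.ones([20, 1]), name='w')
    output = tf.matmul(inp, w, name='output')

meta_graph1 = tf.train.export_meta_graph(graph=graph1)
# The exported MetaGraphDef carries the variable collections as CollectionDefs.
collection_keys = sorted(meta_graph1.collection_def.keys())

# Route 1: import the MetaGraphDef under a scope; collections are restored.
with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, (None, 20), name='input')
    meta_graph.import_scoped_meta_graph(
        meta_graph1, input_map={'input': x}, import_scope='graph1')
    out1 = graph.get_tensor_by_name('graph1/output:0')
    meta_var_names = [v.name for v in tf.global_variables()]
    init = tf.global_variables_initializer()

# Route 2: import only the GraphDef; the nodes arrive, the collections do not.
with tf.Graph().as_default() as plain:
    tf.import_graph_def(graph1.as_graph_def(), name='graph1')
    plain_var_names = [v.name for v in tf.global_variables()]

print(meta_var_names)   # ['graph1/w:0']
print(plain_var_names)  # []

# The re-mapped variable can be initialized and used in the combined graph.
with tf.Session(graph=graph) as sess:
    sess.run(init)
    result = sess.run(out1, feed_dict={x: [[1.0] * 20]})
print(result)  # [[20.]]
```

The second route is essentially what the GraphDef-only approach does, which is why its variables cannot be found through tf.global_variables() afterwards.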

You can also import meta graphs directly from a file.
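A minimal sketch of the file-based route, assuming a hypothetical file name graph1.meta in a temporary directory: tf.train.export_meta_graph writes the MetaGraphDef to disk when given filename, and import_scoped_meta_graph accepts a path in place of the proto. As above, it assumes tensorflow.compat.v1 in graph mode.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf
from tensorflow.python.framework import meta_graph

tf.disable_v2_behavior()

with tf.Graph().as_default() as graph1:
    inp = tf.placeholder(tf.float32, (None, 20), name='input')
    output = tf.identity(inp, name='output')

# Write the MetaGraphDef to disk (binary protobuf by default).
path = os.path.join(tempfile.mkdtemp(), 'graph1.meta')
tf.train.export_meta_graph(filename=path, graph=graph1)

with tf.Graph().as_default() as graph:
    x = tf.placeholder(tf.float32, (None, 20), name='input')
    # Pass the file path instead of a MetaGraphDef object.
    meta_graph.import_scoped_meta_graph(
        path, input_map={'input': x}, import_scope='graph1')
    out1 = graph.get_tensor_by_name('graph1/output:0')

with tf.Session(graph=graph) as sess:
    result = sess.run(out1, feed_dict={x: [[2.0] * 20]})
```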


Assuming that your Protobuf files contain serialized tf.GraphDef protos, you can use the input_map argument of tf.import_graph_def() to connect the two graphs:

# Import graph1.
graph1_def = ...  # tf.GraphDef object
out1_name = "..."  # name of the graph1out tensor in graph1_def.
graph1out, = tf.import_graph_def(graph1_def, return_elements=[out_name])

# Import graph2 and connect it to graph1.
graph2_def = ...  # tf.GraphDef object
inp2_name = "..."  # name of the graph2inp tensor in graph2_def.
out2_name = "..."  # name of the graph2out tensor in graph2_def.
graph2out, = tf.import_graph_def(graph2_def, input_map={inp2_name: graph1out},
                                 return_elements=[out2_name])
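To make the snippet above concrete, here is a sketch with two hypothetical stand-in graphs (tensors in1/out1 and in2/out2) in place of your own files. graph1 is round-tripped through a binary .pb file with tf.train.write_graph and GraphDef.ParseFromString to mimic loading from disk; the file name graph1.pb and the import prefixes g1/g2 are illustrative. It again assumes tensorflow.compat.v1 in graph mode.

```python
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Two small stand-in graphs; in practice these come from your .pb files.
with tf.Graph().as_default() as g1:
    a = tf.placeholder(tf.float32, (None, 3), name='in1')
    tf.multiply(a, 2.0, name='out1')

with tf.Graph().as_default() as g2:
    b = tf.placeholder(tf.float32, (None, 3), name='in2')
    tf.add(b, 1.0, name='out2')

# Serialize graph1 to a binary .pb file, then read it back as a GraphDef.
pb_path = tf.train.write_graph(g1.as_graph_def(), tempfile.mkdtemp(),
                               'graph1.pb', as_text=False)
graph1_def = tf.GraphDef()
with tf.gfile.GFile(pb_path, 'rb') as f:
    graph1_def.ParseFromString(f.read())
graph2_def = g2.as_graph_def()

with tf.Graph().as_default() as combined:
    # Names with ':0' make import_graph_def return Tensors, not Operations.
    graph1out, = tf.import_graph_def(graph1_def, return_elements=['out1:0'],
                                     name='g1')
    # Feed graph1's output into graph2's input placeholder.
    graph2out, = tf.import_graph_def(graph2_def, input_map={'in2:0': graph1out},
                                     return_elements=['out2:0'], name='g2')
    graph1inp = combined.get_tensor_by_name('g1/in1:0')

with tf.Session(graph=combined) as sess:
    result = sess.run(graph2out, feed_dict={graph1inp: [[1.0, 2.0, 3.0]]})

print(result)  # [[3. 5. 7.]]
```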