Restoring a TensorFlow model that uses Iterators

I would suggest using tf.contrib.data.make_saveable_from_iterator, which was designed precisely for this purpose. It is much less verbose and does not require you to change existing code, in particular how you define your iterator.

Here is a working example in which we save everything after step 5 has completed. Note that I don't even need to know what seed the shuffle uses.

import tensorflow as tf

iterator = (
  tf.data.Dataset.range(100)
  .shuffle(10)
  .make_one_shot_iterator())
batch = iterator.get_next(name='batch')

saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
saver = tf.train.Saver()

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  for step in range(10):
    print('{}: {}'.format(step, sess.run(batch)))
    if step == 5:
      saver.save(sess, './foo', global_step=step)

# 0: 1
# 1: 6
# 2: 7
# 3: 3
# 4: 8
# 5: 10
# 6: 12
# 7: 14
# 8: 5
# 9: 17

Then later, if we resume from step 6, we get the same output.

import tensorflow as tf

saver = tf.train.import_meta_graph('./foo-5.meta')
with tf.Session() as sess:
  saver.restore(sess, './foo-5')
  for step in range(6, 10):
    print('{}: {}'.format(step, sess.run('batch:0')))
# 6: 12
# 7: 14
# 8: 5
# 9: 17

When restoring a saved meta graph, you can retrieve the initialization operation by name and then use it again to initialize the input pipeline for inference.

That is, when creating the graph, you can do

    dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

And then restore this operation by doing:

    dataset_init_op = graph.get_operation_by_name('dataset_init')

Here is a self-contained code snippet that compares the results of a randomly initialized model before and after restoring.

Saving an Iterator

import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])
X = tf.placeholder(dtype=tf.float32, shape=[4, 4], name='X')
dataset = tf.data.Dataset.from_tensor_slices(X)
iterator = tf.data.Iterator.from_structure(dataset.output_types, dataset.output_shapes)
dataset_next_op = iterator.get_next()

# name the operation
dataset_init_op = iterator.make_initializer(dataset, name='dataset_init')

w = np.random.random([1, 4])
W = tf.Variable(w, name='W', dtype=tf.float32)
output = tf.multiply(W, dataset_next_op, name='output')
sess = tf.Session()
saver = tf.train.Saver()
sess.run(tf.global_variables_initializer())
sess.run(dataset_init_op, feed_dict={X:data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        # save once the iterator is exhausted, then leave the loop
        saver.save(sess, 'tmp/', global_step=1002)
        break

And then you can restore the same model for inference as follows:

Restoring a Saved Iterator

import os

import numpy as np
import tensorflow as tf

np.random.seed(42)
data = np.random.random([4, 4])
tf.reset_default_graph()
sess = tf.Session()
saver = tf.train.import_meta_graph('tmp/-1002.meta')
ckpt = tf.train.get_checkpoint_state(os.path.dirname('tmp/checkpoint'))
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()

# Restore the init operation
dataset_init_op = graph.get_operation_by_name('dataset_init')

X = graph.get_tensor_by_name('X:0')
output = graph.get_tensor_by_name('output:0')
sess.run(dataset_init_op, feed_dict={X:data})
while True:
    try:
        print(sess.run(output))
    except tf.errors.OutOfRangeError:
        break

I couldn't solve the problem of restoring the iterator's state itself. Since I pre-process my dataset with the map method, applying transformations defined by Python operations wrapped with py_func, which cannot be serialized for storing/restoring, I'll have to initialize my dataset when I restore it anyway.

So the remaining problem is how to feed data to the graph once it is restored. I placed a tf.identity node between the iterator output and my network input, and upon restoring, I feed my data to the identity node. A better solution, which I discovered later, is to use placeholder_with_default(), as described in this answer.