How to get the global_step when restoring checkpoints in Tensorflow?

The general pattern is to have a global_step variable to keep track of steps:

global_step = tf.Variable(0, name='global_step', trainable=False)
train_op = optimizer.minimize(loss, global_step=global_step)

Then you can save with:

saver.save(sess, save_path, global_step=global_step)

When you restore, the value of global_step is restored as well, so training can resume where it left off.
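For example, a minimal restore sketch (checkpoint_dir is an assumed path to your checkpoint directory; the rest of the graph is omitted):

import tensorflow as tf

# Rebuild the same graph, including the global_step variable
global_step = tf.Variable(0, name='global_step', trainable=False)
saver = tf.train.Saver()

with tf.Session() as sess:
    # latest_checkpoint returns the path of the most recent checkpoint
    saver.restore(sess, tf.train.latest_checkpoint(checkpoint_dir))
    print('Resumed at step:', sess.run(global_step))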


This is a bit of a hack, but the other answers did not work for me at all:

import os
import tensorflow as tf

ckpt = tf.train.get_checkpoint_state(checkpoint_dir)

# Extract the step from the checkpoint filename,
# e.g. 'model.ckpt-1234' -> 1234
step = int(os.path.basename(ckpt.model_checkpoint_path).split('-')[-1])
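If you also want to resume training from that checkpoint, restore it with a saver built for the current graph (saver here is an assumed tf.train.Saver):

saver.restore(sess, ckpt.model_checkpoint_path)
start_step = step  # continue the training loop from the extracted step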

Update 9/2017

I'm not sure whether this started working because of updates, but the following method seems to get global_step updating and loading properly:

Create two ops: one to hold global_step and another to increment it:

    global_step = tf.Variable(0, trainable=False, name='global_step')
    increment_global_step = tf.assign_add(global_step, 1,
                                          name='increment_global_step')

Now, in your training loop, run the increment op every time you run your training op:

sess.run([train_op, increment_global_step], feed_dict=feed_dict)
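For context, this is roughly what the loop might look like with periodic checkpointing (num_steps, saver, and save_path are assumptions, defined elsewhere; passing global_step to saver.save appends the step to the filename):

for step in range(num_steps):
    sess.run([train_op, increment_global_step], feed_dict=feed_dict)
    if (step + 1) % 100 == 0:
        # Produces files like save_path-100, save_path-200, ...
        saver.save(sess, save_path, global_step=global_step)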

If you ever want to retrieve your global step value as an integer at any point, just run the following after loading the model:

sess.run(global_step)

This can be useful for creating filenames or calculating the current epoch without having a second TensorFlow Variable to hold that value. For instance, since each increment corresponds to one batch of batch_size records, calculating the current epoch on loading would be something like:

loaded_epoch = sess.run(global_step) * batch_size // num_train_records
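For example, with batch_size = 32 and num_train_records = 50000 (illustrative numbers), a restored global_step of 15625 gives 15625 * 32 // 50000 = 10 completed epochs.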

Tags:

TensorFlow