In what order are SessionRunHook's member functions called?

You can find a tutorial here. It is a little long, but you can skip the part about building the network. Or you can read my short summary below, based on my experience.

First, use a MonitoredSession instead of a normal Session.

A SessionRunHook extends the session.run() calls made through a MonitoredSession.
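A hook plugs into five callbacks: begin(), after_create_session(session, coord), before_run(run_context), after_run(run_context, run_values) and end(session). As a framework-free sketch (the class names SessionRunHookSketch and CountingHook are hypothetical, not TensorFlow classes), a base class with no-op defaults shows why a subclass only overrides the callbacks it needs:

```python
class SessionRunHookSketch:
    """Hypothetical stand-in mirroring the callbacks of tf.train.SessionRunHook.

    Every method is a no-op by default, so subclasses override only what they need.
    """

    def begin(self):
        # Called once before the session is created; the graph can still be modified.
        pass

    def after_create_session(self, session, coord):
        # Called once the session exists; the graph is finalized at this point.
        pass

    def before_run(self, run_context):
        # Called before every run(); in TF this may return SessionRunArgs (extra fetches).
        return None

    def after_run(self, run_context, run_values):
        # Called after every run(); in TF run_values.results holds the extra fetches.
        pass

    def end(self, session):
        # Called once when the session is finished.
        pass


class CountingHook(SessionRunHookSketch):
    """Hypothetical example: overrides only begin() and after_run() to count steps."""

    def begin(self):
        self.steps = 0

    def after_run(self, run_context, run_values):
        self.steps += 1
```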

Next, some common SessionRunHook classes can be found here. A simple one is LoggingTensorHook, but you may want to add the following line after your imports so that its logs are shown when running:

tf.logging.set_verbosity(tf.logging.INFO)

Alternatively, you can implement your own SessionRunHook class. Here is a simple one from the CIFAR-10 tutorial:

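import time
from datetime import datetime

import tensorflow as tf

# `loss` (the model's loss tensor) and `FLAGS` (with log_frequency and
# batch_size) are defined elsewhere in the CIFAR-10 training script.

class _LoggerHook(tf.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    # Called once before the session is created: reset step counter and timer.
    self._step = -1
    self._start_time = time.time()

  def before_run(self, run_context):
    # Called before every run(): count the step and add `loss` to the fetches.
    self._step += 1
    return tf.train.SessionRunArgs(loss)  # Asks for loss value.

  def after_run(self, run_context, run_values):
    # Called after every run(): report once every FLAGS.log_frequency steps.
    if self._step % FLAGS.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time

      loss_value = run_values.results  # value of the `loss` fetch requested above
      examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
      sec_per_batch = float(duration / FLAGS.log_frequency)

      format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                    'sec/batch)')
      print(format_str % (datetime.now(), self._step, loss_value,
                          examples_per_sec, sec_per_batch))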

Here loss (and FLAGS) are defined outside the class. This _LoggerHook prints the information with print, whereas LoggingTensorHook logs it through tf.logging at INFO level.
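The throughput figures in _LoggerHook are plain arithmetic over one logging window: duration is the wall-clock time since the previous report, which assumes log_frequency steps of batch_size examples each. A small sketch of just that arithmetic (the function name throughput and the numbers 10, 128 and 2.0 seconds are illustrative assumptions):

```python
def throughput(log_frequency, batch_size, duration):
    """Reproduce _LoggerHook's arithmetic for one logging window (hypothetical helper).

    duration: wall-clock seconds spent on the last `log_frequency` steps.
    """
    examples_per_sec = log_frequency * batch_size / duration
    sec_per_batch = float(duration / log_frequency)
    return examples_per_sec, sec_per_batch


# e.g. 10 steps of 128 examples each took 2.0 seconds in total
eps, spb = throughput(log_frequency=10, batch_size=128, duration=2.0)
print(eps, spb)  # 640.0 0.2
```

Note that the very first report (step 0) covers only a single step plus startup time while the formula still assumes log_frequency steps, so that first line is not comparable with the later ones.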

Finally, to better understand how it works, here is the execution order for a MonitoredSession in pseudocode:

  call hooks.begin()
  sess = tf.Session()
  call hooks.after_create_session()
  while not stop is requested:  # py code: while not mon_sess.should_stop():
    call hooks.before_run()
    try:
      results = sess.run(merged_fetches, feed_dict=merged_feeds)
    except (errors.OutOfRangeError, StopIteration):
      break
    call hooks.after_run()
  call hooks.end()
  sess.close()
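The pseudocode can be checked with a framework-free simulation that records each callback as it fires. Everything here is a hypothetical stand-in: RecordingHook, simulate, and a Python iterator playing the input pipeline (its StopIteration takes the place of the OutOfRangeError that TF also catches):

```python
class RecordingHook:
    """Hypothetical hook that appends the name of each callback to a shared log."""

    def __init__(self, log):
        self.log = log

    def begin(self):
        self.log.append("begin")

    def after_create_session(self, session, coord):
        self.log.append("after_create_session")

    def before_run(self, run_context):
        self.log.append("before_run")

    def after_run(self, run_context, run_values):
        self.log.append("after_run")

    def end(self, session):
        self.log.append("end")


def simulate(hooks, batches):
    """Drive hooks in the same order as the pseudocode; `batches` plays the input data."""
    data = iter(batches)
    for h in hooks:
        h.begin()
    session = object()  # placeholder for tf.Session()
    for h in hooks:
        h.after_create_session(session, None)
    while True:  # stands in for `while not mon_sess.should_stop()`
        for h in hooks:
            h.before_run(None)
        try:
            next(data)  # stands in for sess.run(merged_fetches, feed_dict=merged_feeds)
        except StopIteration:  # TF also catches errors.OutOfRangeError here
            break
        for h in hooks:
            h.after_run(None, None)
    for h in hooks:
        h.end(session)


log = []
simulate([RecordingHook(log)], batches=["a", "b"])
print(log)
```

With two batches, before_run fires three times but after_run only twice: the third run raises when the data is exhausted, the loop breaks, and end() follows directly.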

Hope this helps.


tf.train.SessionRunHook lets you add your own code around each session run call you execute. To illustrate it, I have created a simple example below:

  1. We want to print the loss value after each update of the parameters.
  2. We will use a SessionRunHook to achieve this.

Create a TensorFlow graph

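import tensorflow as tf
import numpy as np

# Linear model: learn w so that x @ w matches targets produced by fixed weights w0.
x = tf.placeholder(shape=(10, 2), dtype=tf.float32)  # batch of 10 two-feature inputs
w = tf.Variable(initial_value=[[10.], [10.]])        # trainable weights, start far from w0
w0 = [[1.], [1.]]                                    # "true" weights
y = tf.matmul(x, w0)                                 # targets
loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)    # mean squared error
optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)  # training op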
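To see what this loss measures, here is the same mean squared error computed in NumPy on a hand-picked input. The all-ones x is an assumption for illustration only (the training loop feeds random x_). Every row predicts 10 + 10 = 20 while its target is 1 + 1 = 2, so each squared error is 18^2 = 324:

```python
import numpy as np

# Same shapes as the graph: x is (10, 2), both weight matrices are (2, 1).
x = np.ones((10, 2), dtype=np.float32)           # illustrative input, not random
w = np.array([[10.], [10.]], dtype=np.float32)   # starting value of the trainable weights
w0 = np.array([[1.], [1.]], dtype=np.float32)    # "true" weights that generate y

y = x @ w0                           # targets: 2.0 in every row
loss = np.mean((x @ w - y) ** 2)     # (20 - 2)^2 = 324 in every row, so the mean is 324
print(loss)  # 324.0
```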

Create the hook

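class _Hook(tf.train.SessionRunHook):
  def __init__(self, loss):
    self.loss = loss  # tensor to fetch on every step

  def begin(self):
    pass  # nothing to set up before the session is created

  def before_run(self, run_context):
    # Add the loss tensor to the fetches of the upcoming run() call.
    return tf.train.SessionRunArgs(self.loss)

  def after_run(self, run_context, run_values):
    # run_values.results holds the value of the tensor requested in before_run.
    loss_value = run_values.results
    print("loss value:", loss_value)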
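The hook sees the loss even though the training loop only asks for optimizer: MonitoredSession merges the fetches each hook returns from before_run (via SessionRunArgs) with the caller's fetches, executes them in a single run, then hands each hook its own part through run_values.results. A framework-free sketch of that routing, where FakeSession, monitored_run, LossHook and the string fetch names are hypothetical and every hook is assumed to return exactly one fetch:

```python
from collections import namedtuple

# Hypothetical stand-in for tf.train.SessionRunValues (only the `results` field).
RunValues = namedtuple("RunValues", ["results"])


class FakeSession:
    """Hypothetical session that looks fetch names up in a dict instead of running a graph."""

    def __init__(self, values):
        self.values = values

    def run(self, fetches):
        return [self.values[f] for f in fetches]


def monitored_run(session, user_fetch, hooks):
    # Collect each hook's extra fetch (what SessionRunArgs carries in TF).
    extra = [h.before_run(None) for h in hooks]
    # One combined run: the caller's fetch first, then every hook's fetch.
    results = session.run([user_fetch] + extra)
    # Give each hook only its own result, then return the caller's result.
    for h, r in zip(hooks, results[1:]):
        h.after_run(None, RunValues(results=r))
    return results[0]


class LossHook:
    def before_run(self, run_context):
        return "loss"

    def after_run(self, run_context, run_values):
        self.seen = run_values.results
```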

Create a MonitoredSession with the hook

sess = tf.train.MonitoredSession(hooks=[_Hook(loss)])

Train

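for _ in range(10):
  x_ = np.random.random((10, 2))  # fresh random batch each step
  sess.run(optimizer, {x: x_})    # the hook prints the loss after each run
# Output (exact values differ between runs because x_ is random)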
loss value: 21.244701
loss value: 19.39169
loss value: 16.02665
loss value: 16.717144
loss value: 15.389178
loss value: 16.23935
loss value: 14.299083
loss value: 9.624525
loss value: 5.654896
loss value: 10.689494
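The loop above runs a fixed 10 iterations. A hook can also end training: calling run_context.request_stop() inside after_run makes MonitoredSession.should_stop() return True, and TensorFlow ships tf.train.StopAtStepHook for the common "stop after N steps" case. A framework-free sketch of that mechanism, where RunContext, StopAfterNStepsHook and train are hypothetical stand-ins:

```python
class RunContext:
    """Hypothetical stand-in for the stop flag of tf.train.SessionRunContext."""

    def __init__(self):
        self.stop_requested = False

    def request_stop(self):
        self.stop_requested = True


class StopAfterNStepsHook:
    """Hypothetical hook: asks the loop to stop once `num_steps` runs have completed."""

    def __init__(self, num_steps):
        self.num_steps = num_steps
        self.count = 0

    def after_run(self, run_context, run_values):
        self.count += 1
        if self.count >= self.num_steps:
            run_context.request_stop()


def train(hook):
    ctx = RunContext()
    steps = 0
    while not ctx.stop_requested:  # plays the role of `while not sess.should_stop()`
        steps += 1                 # one "sess.run(optimizer, ...)" call
        hook.after_run(ctx, None)
    return steps
```

With this pattern the training loop no longer needs a hard-coded range; the hook decides when to stop.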
