TensorFlow Inference

While it's not as cut and dried as model.predict(), it's still straightforward.

In your model you should have a tensor that computes the final output you're interested in; let's name that tensor output. You may currently have only a loss function. If so, create another tensor (another node in the model's graph) that actually computes the output you want.

For example, if your loss function is:

tf.nn.sigmoid_cross_entropy_with_logits(logits=last_layer_activation, labels=labels)

And you expect your outputs to be in the range [0,1] per class, create another tensor:

output = tf.sigmoid(last_layer_activation)
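
To see why tf.sigmoid is the right companion to that loss, here is a small pure-Python sketch (no TensorFlow needed). It uses the numerically stable formula given in the TensorFlow documentation for sigmoid_cross_entropy_with_logits; the helper names and the sample logits and labels are made up for illustration.

```python
import math


def sigmoid(x):
    """Squash a raw activation (logit) into the open interval (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_cross_entropy_with_logits(logit, label):
    """Stable form: max(x, 0) - x * z + log(1 + exp(-|x|))."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def cross_entropy_from_probability(p, label):
    """Textbook cross-entropy computed from an already-squashed probability."""
    return -(label * math.log(p) + (1.0 - label) * math.log(1.0 - p))


# Made-up values standing in for last_layer_activation and labels.
logits = [-2.0, 0.0, 3.5]
labels = [0.0, 1.0, 1.0]

outputs = [sigmoid(x) for x in logits]
for x, z, p in zip(logits, labels, outputs):
    from_logit = sigmoid_cross_entropy_with_logits(x, z)
    from_prob = cross_entropy_from_probability(p, z)
    print("logit=%5.2f  output=%.4f  loss=%.6f  (from output: %.6f)" % (x, p, from_logit, from_prob))
```

The loss already applies the sigmoid internally, so the output tensor is just that same sigmoid exposed on its own.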

Now, when you call sess.run(...), just request the output tensor. Don't request the optimization op you normally would to train the model. When you request this tensor, TensorFlow will do the minimum work necessary to produce its value (e.g. it won't bother with backprop, loss functions, and all that, because a simple feed-forward pass is all that's necessary to compute output).

So if you're creating a service to return inferences of the model, you'll want to keep the model loaded in memory (or on the GPU) and repeat:

sess.run(output, feed_dict={X: input_data})

You won't need to feed it the labels, because TensorFlow won't compute ops that aren't needed to produce the output you are requesting. You don't have to change your model or anything.
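
If it helps to picture why the labels can be left out, here is a toy dependency graph in plain Python. It is only an illustration of the idea of evaluating just the ancestors of the requested node; it is not how TensorFlow is implemented, and Node and run are hypothetical names.

```python
# Toy illustration only: this is NOT TensorFlow's implementation.
class Node:
    """A named graph node whose value is computed from its input nodes."""

    def __init__(self, name, fn=None, inputs=()):
        self.name = name
        self.fn = fn                  # None marks a placeholder that must be fed
        self.inputs = tuple(inputs)


def run(fetch, feed, trace):
    """Evaluate `fetch`, visiting only the nodes it depends on."""
    if fetch in feed:
        return feed[fetch]
    if fetch.fn is None:
        raise ValueError("placeholder %r was needed but not fed" % fetch.name)
    args = [run(node, feed, trace) for node in fetch.inputs]
    trace.append(fetch.name)          # record that this node was computed
    return fetch.fn(*args)


X = Node("X")                         # input placeholder
labels = Node("labels")               # label placeholder, used only by the loss
hidden = Node("hidden", lambda x: 2.0 * x, [X])
output = Node("output", lambda h: h + 1.0, [hidden])
loss = Node("loss", lambda o, y: (o - y) ** 2, [output, labels])

trace = []
result = run(output, {X: 3.0}, trace)  # no labels fed
print(result, trace)                   # the loss node never runs
```

Asking the same toy graph for loss without feeding labels raises an error, which mirrors what happens if you request a training op without its inputs.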

While this approach might not be as obvious as model.predict(...), I'd argue that it's vastly more flexible. If you start playing with more complex models, you'll probably learn to love this approach. model.predict() is like "thinking inside the box."


Alright, this took way too much time to figure out, so here is the answer for the rest of the world.

Quick reminder: I needed to persist a model that can be dynamically loaded and run for inference without any knowledge of its underpinnings or internals.

Step 1: Create the model as a class, ideally against an interface definition

class Vgg3Model:

    NUM_DENSE_NEURONS = 50
    DENSE_RESHAPE = 32 * (CONSTANTS.IMAGE_SHAPE[0] // 2) * (CONSTANTS.IMAGE_SHAPE[1] // 2)

    def inference(self, images):
        '''
        Portion of the compute graph that takes an input and converts it into a Y output
        '''
        with tf.variable_scope('Conv1') as scope:
            C_1_1 = ld.cnn_layer(images, (5, 5, 3, 32), (1, 1, 1, 1), scope, name_postfix='1')
            C_1_2 = ld.cnn_layer(C_1_1, (5, 5, 32, 32), (1, 1, 1, 1), scope, name_postfix='2')
            P_1 = ld.pool_layer(C_1_2, (1, 2, 2, 1), (1, 2, 2, 1), scope)
        with tf.variable_scope('Dense1') as scope:
            P_1 = tf.reshape(P_1, (-1, self.DENSE_RESHAPE))
            dim = P_1.get_shape()[1].value
            D_1 = ld.mlp_layer(P_1, dim, self.NUM_DENSE_NEURONS, scope, act_func=tf.nn.relu)
        with tf.variable_scope('Dense2') as scope:
            D_2 = ld.mlp_layer(D_1, self.NUM_DENSE_NEURONS, CONSTANTS.NUM_CLASSES, scope)
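        # Keep a named softmax op in the graph, but return the raw logits:
        # sparse_softmax_cross_entropy_with_logits applies softmax itself,
        # and argmax over logits gives the same label as argmax over softmax.
        tf.nn.softmax(D_2, name='prediction')
        return D_2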

    def loss(self, logits, labels):
        '''
        Computes the mean cross-entropy loss, adds it to the 'losses'
        collection and returns the sum of everything in that collection
        '''
        cross_entr = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=labels)
        cross_entr = tf.reduce_mean(cross_entr)
        tf.summary.scalar('cost', cross_entr)
        tf.add_to_collection('losses', cross_entr)
        return tf.add_n(tf.get_collection('losses'), name='total_loss')
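
The class above does not actually declare an interface. A minimal sketch of what one could look like, using Python's abc module, follows. ModelInterface, ToyModel and IncompleteModel are hypothetical names for illustration, and the toy implementation uses plain lists instead of tensors so it runs without TensorFlow.

```python
from abc import ABC, abstractmethod


class ModelInterface(ABC):
    """Hypothetical contract every model class would have to satisfy."""

    @abstractmethod
    def inference(self, images):
        """Map a batch of inputs to outputs."""

    @abstractmethod
    def loss(self, logits, labels):
        """Return a scalar loss for the given outputs and labels."""


class ToyModel(ModelInterface):
    """Stand-in implementation working on plain lists of floats."""

    def inference(self, images):
        return list(images)

    def loss(self, logits, labels):
        return sum((a - b) ** 2 for a, b in zip(logits, labels))


class IncompleteModel(ModelInterface):
    """Forgets to implement loss(), so it cannot be instantiated."""

    def inference(self, images):
        return list(images)


model = ToyModel()
print(model.loss(model.inference([1.0, 2.0]), [1.0, 4.0]))

try:
    IncompleteModel()
except TypeError as err:
    print("rejected:", err)
```

The point of the design is that code loading a model only has to know about inference and loss, and a model class that misses one of them fails as soon as it is instantiated rather than in the middle of a run.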

Step 2: Train your network with whatever inputs you want; in my case I used queue runners and TFRecords. Note that this step is done by a different team, which iterates on, builds, designs, and optimizes the models, and the models can change over time. The output they produce must be retrievable from a remote location so we can dynamically load updated models onto devices (reflashing hardware is a pain, especially when it is geographically distributed). In this instance, the team drops the 3 files associated with a graph saver, plus a pickle of the model used for that training session.

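import pickle
import time

import tensorflow as tf

# vgg3, Inputs, CONSTANTS and gradients_summary are project-specific and not shown here.
model = vgg3.Vgg3Model()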

def create_sess_ops():
    '''
    Creates and returns operations needed for running
    a tensorflow training session
    '''
    GRAPH = tf.Graph()
    with GRAPH.as_default():
        examples, labels = Inputs.read_inputs(CONSTANTS.RecordPaths,
                                          batch_size=CONSTANTS.BATCH_SIZE,
                                          img_shape=CONSTANTS.IMAGE_SHAPE,
                                          num_threads=CONSTANTS.INPUT_PIPELINE_THREADS)
        examples = tf.reshape(examples, [-1, CONSTANTS.IMAGE_SHAPE[0],
                                     CONSTANTS.IMAGE_SHAPE[1], CONSTANTS.IMAGE_SHAPE[2]], name='infer/input')
        logits = model.inference(examples)
        loss = model.loss(logits, labels)
        OPTIMIZER = tf.train.AdamOptimizer(CONSTANTS.LEARNING_RATE)
        gradients = OPTIMIZER.compute_gradients(loss)
        apply_gradient_op = OPTIMIZER.apply_gradients(gradients)
        gradients_summary(gradients)
        summaries_op = tf.summary.merge_all()
        return [apply_gradient_op, summaries_op, loss, logits], GRAPH

def main():
    '''
    Run and Train CIFAR 10
    '''
    print('starting...')
    ops, GRAPH = create_sess_ops()
    total_duration = 0.0
    with tf.Session(graph=GRAPH) as SESSION:
        COORDINATOR = tf.train.Coordinator()
        THREADS = tf.train.start_queue_runners(SESSION, COORDINATOR)
        SESSION.run(tf.global_variables_initializer())
        SUMMARY_WRITER = tf.summary.FileWriter('Tensorboard/' + CONSTANTS.MODEL_NAME, graph=GRAPH)
        GRAPH_SAVER = tf.train.Saver()

        for EPOCH in range(CONSTANTS.EPOCHS):
            duration = 0
            error = 0.0
            start_time = time.time()
            for batch in range(CONSTANTS.MINI_BATCHES):
                _, summaries, cost_val, prediction = SESSION.run(ops)
                error += cost_val
            duration += time.time() - start_time
            total_duration += duration
            SUMMARY_WRITER.add_summary(summaries, EPOCH)
            print('Epoch %d: loss = %.2f (%.3f sec)' % (EPOCH, error, duration))
            if EPOCH == CONSTANTS.EPOCHS - 1 or error < 0.005:
                print('Done training for %d epochs. (%.3f sec)' % (EPOCH, total_duration))
                break
        GRAPH_SAVER.save(SESSION, 'models/' + CONSTANTS.MODEL_NAME + '.model')
        with open('models/' + CONSTANTS.MODEL_NAME + '.pkl', 'wb') as output:
            pickle.dump(model, output)
        COORDINATOR.request_stop()
        COORDINATOR.join(THREADS)

Step 3: Run some inference. Load your pickled model, create a new graph by piping a new placeholder into the logits, and then call the saver's restore. DO NOT RESTORE THE WHOLE GRAPH; JUST THE VARIABLES.

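import pickle

import numpy as np
import tensorflow as tf
from scipy import misc  # assumed source of misc.imread; it was removed in SciPy 1.2, so this needs an older SciPy

MODEL_PATH = 'models/' + CONSTANTS.MODEL_NAME + '.model'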
imgs_bsdir = 'C:/data/cifar_10/train/'

images = tf.placeholder(tf.float32, shape=(1, 32, 32, 3))
with open('models/vgg3.pkl', 'rb') as model_in:
    model = pickle.load(model_in)
logits = model.inference(images)

def run_inference():
    '''Runs inference against a loaded model'''
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        new_saver = tf.train.Saver()
        new_saver.restore(sess, MODEL_PATH)
        print("Starting...")
        for i in range(20, 30):
            print(str(i) + '.png')
            img = misc.imread(imgs_bsdir + str(i) + '.png').astype(np.float32) / 255.0
            img = img.reshape(1, 32, 32, 3)
            pred = sess.run(logits, feed_dict={images : img})
            max_node = np.argmax(pred)
            print('predicted label: ' + str(max_node))
        print('done')

run_inference()
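
The last step, np.argmax(pred), just picks the index of the highest score. A small pure-Python sketch with made-up scores shows that the predicted label is the same whether the fetched tensor holds raw scores or softmax probabilities; the softmax and argmax helpers here are written out by hand for illustration.

```python
import math


def softmax(scores):
    """Turn raw scores into probabilities that sum to 1."""
    shift = max(scores)                      # subtract the max for numerical stability
    exps = [math.exp(s - shift) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value, like np.argmax on a 1-D array."""
    return max(range(len(values)), key=values.__getitem__)


scores = [1.2, -0.3, 4.1, 0.5]               # made-up outputs for one image
probs = softmax(scores)

print("label from scores:", argmax(scores))
print("label from probabilities:", argmax(probs))
print("confidence: %.3f" % probs[argmax(probs)])
```

Softmax only matters at inference time if you also want a confidence value alongside the label.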

There are definitely ways to improve on this using interfaces and maybe packaging everything up better, but this is working and sets the stage for how we will be moving forward.

FINAL NOTE: When we finally pushed this to production, we ended up having to ship the stupid `mymodel_model.py` file down with everything to build up the graph. So we now enforce a naming convention for all models, and there is also a coding standard for production model runs so we can do this properly.
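
A likely reason the source file has to travel with the pickle: pickle stores a reference to the class (module name plus class name), not the class's code. The sketch below shows this with a standard-library class, Fraction, standing in for the model class. Renaming the module inside the payload is just a trick to simulate a device where the module was never shipped.

```python
import pickle
from fractions import Fraction

# Fraction stands in for a model class such as Vgg3Model.
payload = pickle.dumps(Fraction(3, 4))

# The payload names the module and class to import at load time.
print(b"fractions" in payload, b"Fraction" in payload)

# Loading works when that module is importable on the target machine.
restored = pickle.loads(payload)
print(restored)

# Simulate a machine without the module by pointing the payload at a
# module name of the same length that does not exist.
broken = payload.replace(b"fractions", b"fractionz")
try:
    pickle.loads(broken)
    load_failed = False
except ImportError as err:
    load_failed = True
    print("cannot unpickle:", err)
```

So the pickle only carries the instance's state, and whatever module defines the class must be importable wherever the pickle is loaded.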

Good Luck!