How to properly use tf.metrics.accuracy?

TL;DR

The tf.metrics.accuracy function calculates how often predictions match labels. It does this with two local variables it creates, total and count, which accumulate the number of correct predictions and the number of comparisons made.

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1), 
                                  predictions=tf.argmax(logits,1))

print(sess.run([acc, acc_op]))
print(sess.run([acc]))
# Output
#[0.0, 0.66666669]
#[0.66666669]
  • acc (accuracy): simply returns the metric computed from total and count; it doesn't update the metric.
  • acc_op (update op): updates the metric.

To understand why acc returns 0.0 here, go through the details below.


Details using a simple example:

logits = tf.placeholder(tf.int64, [2,3])
labels = tf.Variable([[0, 1, 0], [1, 0, 1]])

acc, acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),   
                                  predictions=tf.argmax(logits,1))

Initialize the variables:

Since tf.metrics.accuracy creates two local variables, total and count, we need to call tf.local_variables_initializer() to initialize them.

sess = tf.Session()

sess.run(tf.local_variables_initializer())
sess.run(tf.global_variables_initializer())

stream_vars = [i for i in tf.local_variables()]
print(stream_vars)

#[<tf.Variable 'accuracy/total:0' shape=() dtype=float32_ref>,
# <tf.Variable 'accuracy/count:0' shape=() dtype=float32_ref>]

Understanding update ops and accuracy calculation:

print('acc:',sess.run(acc, {logits:[[0,1,0],[1,0,1]]}))
#acc: 0.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [0.0, 0.0]

The above returns 0.0 for the accuracy because total and count are still zero, despite the matching inputs: acc only reads the local variables, it never updates them.

print('ops:', sess.run(acc_op, {logits:[[0,1,0],[1,0,1]]})) 
#ops: 1.0

print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [2.0, 2.0]

With the new inputs, the accuracy is calculated when the update op is called. Note: since all the logits and labels match, we get an accuracy of 1.0, and the local variables total and count give the total number of correct predictions and the total number of comparisons made, respectively.

Now we call accuracy with the new inputs (not the update ops):

print('acc:', sess.run(acc,{logits:[[1,0,0],[0,1,0]]}))
#acc: 1.0

The accuracy call doesn't update the metric with the new inputs; it just returns the value computed from the two local variables. Note: the logits and labels don't match in this case. Now calling the update op again:

print('op:',sess.run(acc_op,{logits:[[0,1,0],[0,1,0]]}))
#op: 0.75 
print('[total, count]:',sess.run(stream_vars)) 
#[total, count]: [3.0, 4.0]

The metric is updated with the new inputs: only one of the two new predictions matches its label, so total/count becomes 3/4 = 0.75.


More information on how to use the metrics during training and how to reset them during validation can be found here.
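
As a minimal sketch (TF 1.x; the variable scope name 'validation_metrics' is an assumption used only for illustration), the metric's local variables can be collected by scope and reset with an initializer op before each validation pass:

# Create the metric inside a named scope so its local variables are easy to find
with tf.variable_scope('validation_metrics'):
    val_acc, val_acc_op = tf.metrics.accuracy(labels=tf.argmax(labels, 1),
                                              predictions=tf.argmax(logits, 1))

# Collect the metric's local variables (total and count) ...
metric_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                scope='validation_metrics')
# ... and build an op that resets them to zero
metric_reset_op = tf.variables_initializer(metric_vars)

# Run before each validation pass so the streaming accuracy starts fresh
sess.run(metric_reset_op)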


In TF 2.0, if you are using the tf.keras API, you can define a custom class myAccuracy which inherits from tf.keras.metrics.Accuracy and overrides the update_state method like this:

# imports
# ...
class myAccuracy(tf.keras.metrics.Accuracy):
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.argmax(y_true,1)
        y_pred = tf.argmax(y_pred,1)
        return super(myAccuracy,self).update_state(y_true,y_pred,sample_weight)
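
As a quick standalone sanity check (with made-up example values, running eagerly), the custom metric can be exercised on its own:

m = myAccuracy()
m.update_state([[0, 1, 0], [1, 0, 1]], [[0.1, 0.8, 0.1], [0.1, 0.2, 0.7]])
print(m.result().numpy())
# 1.0 -- both argmax predictions match the argmax of the labels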

Then, when compiling the model you can add metrics in the usual way.

from my_awesome_models import discriminador

discriminador.compile(tf.keras.optimizers.Adam(),
                      loss=tf.nn.softmax_cross_entropy_with_logits,
                      metrics=[myAccuracy()])

from my_puzzling_datasets import train_dataset,test_dataset

discriminador.fit(train_dataset.shuffle(70000).repeat().batch(1000), 
                  epochs=1,steps_per_epoch=1, 
                  validation_data=test_dataset.shuffle(70000).batch(1000), 
                  validation_steps=1)

# Train for 1 steps, validate for 1 steps
# 1/1 [==============================] - 3s 3s/step - loss: 0.1502 - accuracy: 0.9490 - val_loss: 0.1374 - val_accuracy: 0.9550

Or evaluate your model over the whole dataset:

discriminador.evaluate(test_dataset.batch(TST_DSET_LENGTH))
#> [0.131587415933609, 0.95354694]

Tags:

Tensorflow