Batch normalization in tf.keras does not update the moving mean and moving variance

This is because tf.keras.layers.BatchNormalization inherits from tf.keras.layers.Layer. The Keras API handles update ops as part of its fit and evaluate loops. This in turn means that, outside those loops, the layer won't add its update ops to the tf.GraphKeys.UPDATE_OPS collection.

So, to make it work, you need to add the update ops to the collection manually:

hidden = tf.keras.layers.Dense(units, activation=None)(out)
batch_normed = tf.keras.layers.BatchNormalization(trainable=True) 
layer = batch_normed(hidden)

This creates a separate layer instance (batch_normed), so you keep a reference to the layer and can access its updates:

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, batch_normed.updates)

This adds the update ops to the required collection. See also https://github.com/tensorflow/tensorflow/issues/25525


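# bn1 is the tf.keras.layers.BatchNormalization layer instance;
# add its two update ops (moving mean and moving variance) one at a time
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)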

This fixes the error that

tf.control_dependencies(update_ops)

otherwise raises.
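For context, here is a minimal sketch of how the collected update ops are usually consumed, assuming a TensorFlow version that provides the tf.compat.v1 graph API. The variable moving_mean and its assign_add op are hypothetical stand-ins for the BatchNormalization layer's moving statistic and one of its update ops, and the no_op named "train" stands in for an optimizer's minimize op.

```python
import tensorflow as tf

tf1 = tf.compat.v1

g = tf.Graph()
with g.as_default():
    # Hypothetical stand-in for a BatchNormalization moving statistic
    moving_mean = tf1.Variable(0.0, name="moving_mean")
    # Hypothetical stand-in for one of the layer's update ops
    update = tf1.assign_add(moving_mean, 1.0)
    tf1.add_to_collection(tf1.GraphKeys.UPDATE_OPS, update)

    update_ops = tf1.get_collection(tf1.GraphKeys.UPDATE_OPS)
    # Ops created inside this block run only after the update ops have run
    with tf1.control_dependencies(update_ops):
        train_op = tf1.no_op(name="train")  # stand-in for optimizer.minimize(...)
    init = tf1.global_variables_initializer()

with tf1.Session(graph=g) as sess:
    sess.run(init)
    for _ in range(3):
        sess.run(train_op)
    value = sess.run(moving_mean)

print(value)  # 3.0
```

Because train_op has a control dependency on every op in the collection, each sess.run(train_op) also runs the update, so the stand-in statistic reaches 3.0 after three steps without being fetched explicitly.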

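If you use

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, batch_normed.updates)

then

tf.get_collection(tf.GraphKeys.UPDATE_OPS)

returns a nested list such as [[something]]: tf.add_to_collection stores its value argument as a single item, so the whole updates list becomes one element of the collection.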

If you instead use

tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[0])
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, bn1.updates[1])
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

then

tf.get_collection(tf.GraphKeys.UPDATE_OPS)

returns a flat list such as [something1, something2, ...].
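A minimal sketch comparing the two approaches, assuming a TensorFlow version that provides the tf.compat.v1 graph API. The two no_op ops are hypothetical stand-ins for bn1.updates, and the collection name "nested_updates" is made up only to keep the two results apart. Looping over the list also avoids hard-coding the indices [0] and [1].

```python
import tensorflow as tf

tf1 = tf.compat.v1

g = tf.Graph()
with g.as_default():
    # Hypothetical stand-ins for bn1.updates (a list of two update ops)
    updates = [tf1.no_op(name="update_mean"), tf1.no_op(name="update_var")]

    # Adding the whole list stores it as ONE item: [[op, op]]
    tf1.add_to_collection("nested_updates", updates)
    nested = tf1.get_collection("nested_updates")

    # Adding each op separately gives a flat list: [op, op]
    for op in updates:
        tf1.add_to_collection(tf1.GraphKeys.UPDATE_OPS, op)
    flat = tf1.get_collection(tf1.GraphKeys.UPDATE_OPS)

    print(len(nested), len(flat))  # 1 2

    # The flat list is accepted by control_dependencies
    with tf1.control_dependencies(flat):
        train_op = tf1.no_op(name="train")

    # The nested list is rejected: its only element is a list, not an op
    try:
        with tf1.control_dependencies(nested):
            pass
        nested_ok = True
    except TypeError:
        nested_ok = False

print(nested_ok)  # False
```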

I think this is the solution.

But the outputs are different, and I don't know which one is correct.