significance of "trainable" and "training" flag in tf.layers.batch_normalization

This is quite complicated, and in TF 2.0 the behavior changed. See the docstring of the v2 BatchNormalization layer:

https://github.com/tensorflow/tensorflow/blob/095272a4dd259e8acd3bc18e9eb5225e7a4d7476/tensorflow/python/keras/layers/normalization_v2.py#L26

About setting layer.trainable = False on a BatchNormalization layer:

The meaning of setting layer.trainable = False is to freeze the layer, i.e. its internal state will not change during training:
its trainable weights will not be updated during fit() or train_on_batch(), and its state updates will not be run. This does not necessarily mean that the layer is run in inference mode (which is normally controlled by the training argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch). This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case. Note that:

  • This behavior only occurs as of TensorFlow 2.0. In 1.*, setting layer.trainable = False would freeze the layer but would not switch it to inference mode.
  • Setting trainable on a model containing other layers will recursively set the trainable value of all inner layers.
  • If the value of the trainable attribute is changed after calling compile() on a model, the new value doesn't take effect for this model until compile() is called again.
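The TF 2.x rule quoted above can be modelled in a few lines of plain Python. This is a toy sketch, not Keras source: the helper `effective_training` and its `tf2_behavior` switch are hypothetical, and they only capture the documented outcome that a BatchNormalization layer with `trainable = False` runs in inference mode whatever `training` argument it receives.

```python
def effective_training(training, trainable, tf2_behavior=True):
    """Toy model (hypothetical helper, not a TensorFlow API) of which mode a
    BatchNormalization layer ends up in.

    training:  the `training` argument passed when calling the layer.
    trainable: the layer's `trainable` attribute.
    """
    if tf2_behavior and not trainable:
        # TF 2.x: a frozen BatchNormalization layer always runs in inference mode.
        return False
    # TF 1.x, or a trainable layer: the `training` argument alone decides.
    return training


# The four combinations under TF 2.x semantics.
for training in (True, False):
    for trainable in (True, False):
        mode = "training" if effective_training(training, trainable) else "inference"
        print(f"training={training!s:5} trainable={trainable!s:5} -> {mode} mode")
```

Under TF 1.x semantics (`tf2_behavior=False`) the same call with `training=True, trainable=False` stays in training mode, which is the case explored in the experiment further down.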

training controls whether to use the training-mode batchnorm (which uses statistics from this minibatch) or inference-mode batchnorm (which uses averaged statistics across the training data). trainable controls whether the variables created inside the batchnorm process are themselves trainable.


The batch norm has two phases:

1. Training:
   -  Normalize layer activations using the mean and variance of the current mini-batch, then scale and shift with `gamma` and `beta`.
      (`training` should be `True`.)
   -  Update the `moving_avg` and `moving_var` statistics.
      (`trainable` should be `True`, and the ops in `tf.GraphKeys.UPDATE_OPS` must be run.)
2. Inference:
   -  Normalize layer activations using `moving_avg` and `moving_var`, then scale and shift with `gamma` and `beta`.
      (`training` should be `False`.)
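A minimal NumPy sketch of the two phases, assuming an NHWC input normalized per channel (last axis) and the `tf.layers.batch_normalization` defaults `momentum=0.99` and `epsilon=0.001`. The function `batch_norm` and its `update_moving_stats` switch are illustrative only; the switch stands in for whether the moving-average updates are actually applied, which is what the `training=True, trainable=False` case below ends up disabling. The sketch uses the plain (biased) batch variance, so TensorFlow's kernels may differ in detail.

```python
import numpy as np


def batch_norm(x, gamma, beta, moving_mean, moving_var, training,
               update_moving_stats=True, momentum=0.99, epsilon=1e-3):
    """Illustrative batch norm over an NHWC tensor (not TensorFlow's code).

    Returns (output, new_moving_mean, new_moving_var).
    """
    if training:
        # Training phase: statistics of the current mini-batch, per channel.
        mean = x.mean(axis=(0, 1, 2))
        var = x.var(axis=(0, 1, 2))  # biased variance; TF may differ in detail
        if update_moving_stats:
            # Exponential moving average: m <- momentum * m + (1 - momentum) * batch_stat
            moving_mean = momentum * moving_mean + (1 - momentum) * mean
            moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference phase: use the accumulated moving statistics instead.
        mean, var = moving_mean, moving_var
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    return gamma * x_hat + beta, moving_mean, moving_var


# Same parameter values as the example below: beta = 1, gamma = 2, mean 0, variance 1.
rng = np.random.default_rng(0)
x = rng.integers(0, 10, size=(1, 2, 2, 4)).astype(np.float32)
gamma, beta = np.full(4, 2.0), np.full(4, 1.0)
mean0, var0 = np.zeros(4), np.ones(4)

out_inf, _, _ = batch_norm(x, gamma, beta, mean0, var0, training=False)
out_train, m, v = batch_norm(x, gamma, beta, mean0, var0, training=True)
print(np.round(out_inf[0, 0, 0]))  # close to 2 * x + 1 for that pixel
print(m)                           # 0.01 * per-channel batch mean
```

In training mode each channel of the output has mean `beta`, because the batch statistics remove the data's own mean; in inference mode the output depends only on the fixed moving statistics.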

Example code to illustrate a few cases (TensorFlow 1.x API):

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

## 1. Build the graph

# random image
img = np.random.randint(0, 10, (2, 2, 4)).astype(np.float32)

# batch norm params initialized
beta = np.ones(4, dtype=np.float32) * 1      # all ones
gamma = np.ones(4, dtype=np.float32) * 2     # all twos
moving_mean = np.zeros(4, dtype=np.float32)  # all zeros
moving_var = np.ones(4, dtype=np.float32)    # all ones

# placeholder for the input image (a batch of one 2x2x4 image)
_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 4), name='input')

# batch norm: change `training` / `trainable` here to try the other cases
out = tf.layers.batch_normalization(
    _input,
    beta_initializer=tf.constant_initializer(beta),
    gamma_initializer=tf.constant_initializer(gamma),
    moving_mean_initializer=tf.constant_initializer(moving_mean),
    moving_variance_initializer=tf.constant_initializer(moving_var),
    training=False, trainable=False)

# ops that update the moving mean/variance (empty if none were collected)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
init_op = tf.global_variables_initializer()

## 2. Run the graph in a session

with tf.Session() as sess:
    # init the variables
    sess.run(init_op)

    for _ in range(2):
        # running update_ops applies any moving-average updates
        _, o = sess.run([update_ops, out], feed_dict={_input: np.expand_dims(img, 0)})
        print('beta', sess.run('batch_normalization/beta:0'))
        print('gamma', sess.run('batch_normalization/gamma:0'))
        print('moving_avg', sess.run('batch_normalization/moving_mean:0'))
        print('moving_variance', sess.run('batch_normalization/moving_variance:0'))
        print('out', np.round(o))
        print('')
```

When training=False and trainable=False:

  img = [[[4., 5., 9., 0.]...
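  out = [[ 9. 11. 19.  1.]...
  The activation is normalized with the moving statistics, but since `moving_avg` is 0 and `moving_var` is 1 this is (almost) the identity, so the output is just scaled/shifted by `gamma` and `beta`: 2 * x + 1.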
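The printed row can be checked by hand. A small sketch, assuming the default `epsilon=0.001` of `tf.layers.batch_normalization`: in inference mode the layer computes `gamma * (x - moving_mean) / sqrt(moving_var + epsilon) + beta`, which with the initial values above is almost exactly `2 * x + 1`.

```python
import numpy as np

# First pixel of the random image and the initial parameters from the example.
x = np.array([4.0, 5.0, 9.0, 0.0])
gamma, beta = 2.0, 1.0
moving_mean, moving_var = 0.0, 1.0
epsilon = 1e-3  # default epsilon of tf.layers.batch_normalization

out = gamma * (x - moving_mean) / np.sqrt(moving_var + epsilon) + beta
print(np.round(out))  # [ 9. 11. 19.  1.]
```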

When training=True and trainable=False:

  out = [[ 2.  2.  3. -1.] ...
  The activation is normalized using the mean and variance of the current batch, then scaled/shifted by `gamma` and `beta`.
  The moving averages are not updated.

When training=True and trainable=True:

  The output is the same as above, but `moving_avg` and `moving_var` are updated to new values.

  moving_avg [0.03249997 0.03499997 0.06499994 0.02749997]
  moving_variance [1.0791667 1.1266665 1.0999999 1.0925]
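The updated `moving_avg` is consistent with the usual moving-average rule `new = momentum * old + (1 - momentum) * batch_mean`. A sketch that back-solves the per-channel batch mean from the printed values, assuming the default `momentum=0.99` and that the print shows the state after a single update from the initial value 0.

```python
import numpy as np

momentum = 0.99  # default momentum of tf.layers.batch_normalization
old_mean = np.zeros(4)  # moving mean before the update (initialized to zeros)
printed_mean = np.array([0.03249997, 0.03499997, 0.06499994, 0.02749997])

# new = momentum * old + (1 - momentum) * batch_mean  =>  solve for batch_mean
batch_mean = (printed_mean - momentum * old_mean) / (1 - momentum)
print(np.round(batch_mean, 4))
```

The recovered means are 3.25, 3.5, 6.5 and 2.75, each a multiple of 0.25, which is what you expect from averaging four integer pixel values per channel.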