Global Weight Decay in Keras

According to this issue on the GitHub repo (https://github.com/fchollet/keras/issues/2717), there is no way to do global weight decay in Keras. I answered it there, and I am writing it up here so that others with the same problem do not have to look further for an answer.

To get global weight decay in Keras, regularizers have to be added to every layer in the model. In my models these layers are batch normalization (beta_regularizer/gamma_regularizer) and dense/convolution (W_regularizer/b_regularizer in Keras 1; kernel_regularizer/bias_regularizer in Keras 2) layers.

Layer-wise regularization is described in the Keras documentation: https://keras.io/regularizers/.
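
To make the per-layer route concrete, here is a minimal sketch of a small model with regularizers attached at construction time. The architecture and the 1e-4 coefficient are placeholders of my own; the argument names are the Keras 2 ones (kernel_regularizer/bias_regularizer, beta_regularizer/gamma_regularizer):

import keras
from keras import layers, regularizers

wd = 1e-4  # example decay coefficient (placeholder, not from the original post)

inputs = layers.Input(shape=(32, 32, 3))
x = layers.Conv2D(64, 3, padding='same',
                  kernel_regularizer=regularizers.l2(wd),
                  bias_regularizer=regularizers.l2(wd))(inputs)
x = layers.BatchNormalization(beta_regularizer=regularizers.l2(wd),
                              gamma_regularizer=regularizers.l2(wd))(x)
x = layers.GlobalAveragePooling2D()(x)
outputs = layers.Dense(10, activation='softmax',
                       kernel_regularizer=regularizers.l2(wd),
                       bias_regularizer=regularizers.l2(wd))(x)
model = keras.models.Model(inputs, outputs)

Every layer that holds weights needs its own regularizer arguments; the loop-based approaches below automate exactly that bookkeeping.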


There is no way to directly apply a "global" weight decay to a whole Keras model at once.

However, you can emulate it by looping through the model's layers and manually adding the regularization losses to the appropriate layers. Here's the relevant code snippet:

import keras

model = keras.applications.ResNet50(include_top=True, weights='imagenet')
alpha = 0.00002  # weight decay coefficient

for layer in model.layers:
    if isinstance(layer, keras.layers.Conv2D) or isinstance(layer, keras.layers.Dense):
        # bind the current layer via a default argument to avoid the
        # late-binding pitfall of lambdas created inside a loop
        layer.add_loss(lambda layer=layer: keras.regularizers.l2(alpha)(layer.kernel))
    if hasattr(layer, 'bias_regularizer') and layer.use_bias:
        layer.add_loss(lambda layer=layer: keras.regularizers.l2(alpha)(layer.bias))
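
A quick sanity check (my addition, assuming a recent tf.keras/Keras version that accepts zero-argument callables in add_loss): after attaching the losses, model.losses should contain one L2 term per regularized kernel and bias, and compiling the model folds these terms into the training loss.

print(len(model.losses))  # one L2 term per regularized kernel/bias
model.compile(optimizer='sgd', loss='categorical_crossentropy')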

Here is the full code for applying weight decay to an arbitrary Keras model (inspired by the snippet above):

import tensorflow as tf

# a utility function to add weight decay after the model is defined
def add_weight_decay(model, weight_decay):
    if (weight_decay is None) or (weight_decay == 0.0):
        return

    # recurse into nested models and add an L2 loss for every trainable weight
    def add_decay_loss(m, factor):
        if isinstance(m, tf.keras.Model):
            for layer in m.layers:
                add_decay_loss(layer, factor)
        else:
            for param in m.trainable_weights:
                with tf.keras.backend.name_scope('weight_regularizer'):
                    # bind param via a default argument so each lambda
                    # regularizes its own weight, not the last one in the loop
                    regularizer = lambda param=param: tf.keras.regularizers.l2(factor)(param)
                    m.add_loss(regularizer)

    # weight decay and L2 regularization differ by a factor of 2:
    # the gradient of (wd/2) * ||w||^2 is wd * w
    add_decay_loss(model, weight_decay / 2.0)
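
A hypothetical usage sketch (the toy model and the 1e-4 coefficient are my own placeholders, not part of the original answer): define or load a model, attach the decay losses, then compile so they are included in the training objective.

model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
    tf.keras.layers.Dense(10, activation='softmax'),
])
add_weight_decay(model, weight_decay=1e-4)  # equivalent to an l2 factor of 5e-5 per weight
model.compile(optimizer='sgd', loss='sparse_categorical_crossentropy')
print(len(model.losses))  # one term per trainable weight (kernels and biases)

Note that this utility applies L2 to every trainable weight, including biases and batch-norm parameters, whereas the first snippet only touches Conv2D/Dense kernels and biases.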