Get Gradients with Keras TensorFlow 2.0

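To compute the gradients of the loss with respect to the weights, run the forward pass and the loss computation inside a tf.GradientTape context, where x is a batch of inputs, y_true the matching targets and loss_fn any Keras loss: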

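import tensorflow as tf

with tf.GradientTape() as tape:
    y_pred = model(x)               # forward pass on a batch of inputs x
    loss = loss_fn(y_true, y_pred)  # e.g. loss_fn = tf.keras.losses.MeanSquaredError()

grads = tape.gradient(loss, model.trainable_weights)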

This is (arguably poorly) documented in the tf.GradientTape API reference.

We do not need to call tape.watch on the weights, because trainable variables accessed inside the tape context are watched automatically.
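A minimal end-to-end sketch of the snippet above, assuming a hypothetical toy regression model, random data and mean squared error as the loss. It shows that tape.gradient returns one gradient tensor per trainable variable, in the same order and with the same shape, without any explicit tape.watch call.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy setup: a small regression model and one random batch
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4, activation="relu"),
    tf.keras.layers.Dense(1),
])
loss_fn = tf.keras.losses.MeanSquaredError()
x = np.random.rand(8, 3).astype("float32")
y_true = np.random.rand(8, 1).astype("float32")

with tf.GradientTape() as tape:
    y_pred = model(x)               # forward pass is recorded on the tape
    loss = loss_fn(y_true, y_pred)  # scalar loss

# Trainable variables are watched automatically, so no tape.watch is needed
grads = tape.gradient(loss, model.trainable_weights)

for w, g in zip(model.trainable_weights, grads):
    print(w.name, g.shape)  # each gradient has the shape of its variable
```

The gradients can then be passed to an optimizer, for example with optimizer.apply_gradients(zip(grads, model.trainable_weights)).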

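If you want the gradient with respect to the input instead (for example, for a saliency map), it can be written as a function. The input is a plain tensor rather than a trainable variable, so it has to be watched explicitly: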

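import tensorflow as tf

def gradient(model, x):
    """Return d(model output)/d(input) as a NumPy array."""
    x_tensor = tf.convert_to_tensor(x, dtype=tf.float32)
    with tf.GradientTape() as t:
        t.watch(x_tensor)  # x_tensor is a plain tensor, so watch it explicitly
        output = model(x_tensor)
    # for a non-scalar output, this is the gradient of the sum of its elements
    return t.gradient(output, x_tensor).numpy()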
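A small check of why the explicit watch matters, as a sketch using a hypothetical helper input_gradient with a watch flag and a single linear layer without bias. For y = x W the gradient dy/dx is W transposed, so the result can be compared with the kernel directly; without tape.watch the tape does not track the input and tape.gradient returns None.

```python
import numpy as np
import tensorflow as tf

def input_gradient(model, x, watch=True):
    """Hypothetical variant of gradient() that can skip tape.watch."""
    x_tensor = tf.convert_to_tensor(x, dtype=tf.float32)
    with tf.GradientTape() as t:
        if watch:
            t.watch(x_tensor)  # plain tensors are not tracked automatically
        out = model(x_tensor)
    return t.gradient(out, x_tensor)

# One linear layer without bias: y = x @ W
model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(1, use_bias=False),
])
x = np.array([[1.0, 2.0]], dtype="float32")

g = input_gradient(model, x)
kernel = model.trainable_weights[0].numpy()  # shape (2, 1)
print(np.allclose(g.numpy(), kernel.T))       # True: dy/dx equals W transposed
print(input_gradient(model, x, watch=False))  # None: input was never watched
```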

Also have a look here: https://github.com/tensorflow/tensorflow/issues/31542#issuecomment-630495970

In that comment, richardwth wrote a subclass of the Keras TensorBoard callback that also logs the gradients.

I adapted it as follows:

import tensorflow as tf


class ExtendedTensorBoard(tf.keras.callbacks.TensorBoard):
    """TensorBoard callback that additionally logs histograms of the gradients."""

    def __init__(self, val_dataset, **kwargs):
        super().__init__(**kwargs)
        # unbatched tf.data.Dataset of (features, labels) used for the gradients
        self.val_dataset = val_dataset

    def _log_gradients(self, epoch):
        # NOTE: _writers is a private attribute of the TensorBoard callback,
        # so it may change between TensorFlow versions
        writer = self._writers['train']

        # here we use one batch of validation data to calculate the gradients
        features, y_true = next(iter(self.val_dataset.batch(100).take(1)))

        with tf.GradientTape() as g:
            y_pred = self.model(features)  # forward propagation
            loss = self.model.compiled_loss(y_true=y_true, y_pred=y_pred)  # calculate loss
        gradients = g.gradient(loss, self.model.trainable_weights)  # back-propagation

        with writer.as_default():
            # In eager mode, gradients have no name, so take the names from model.trainable_weights
            for weights, grads in zip(self.model.trainable_weights, gradients):
                tf.summary.histogram(
                    weights.name.replace(':', '_') + '_grads', data=grads, step=epoch)

        writer.flush()

    def on_epoch_end(self, epoch, logs=None):
        # Overriding on_epoch_end replaces the one in tf.keras.callbacks.TensorBoard,
        # so call the parent version explicitly to keep the standard logging.
        super().on_epoch_end(epoch, logs=logs)

        if self.histogram_freq and epoch % self.histogram_freq == 0:
            self._log_gradients(epoch)
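The TensorBoard subclass relies on a private attribute (_writers) and on model.compiled_loss, which may not be available in every Keras version. A sketch of a version-independent alternative, assuming you pass the data, the loss function and a log directory yourself: a plain Callback (the name GradientHistogramCallback and its parameters are hypothetical) that opens its own writer with the public tf.summary.create_file_writer and logs the same kind of gradient histograms.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class GradientHistogramCallback(tf.keras.callbacks.Callback):
    """Hypothetical callback: logs gradient histograms with its own summary writer."""

    def __init__(self, x, y, loss_fn, log_dir, freq=1):
        super().__init__()
        self.x = tf.convert_to_tensor(x, dtype=tf.float32)
        self.y = tf.convert_to_tensor(y, dtype=tf.float32)
        self.loss_fn = loss_fn
        self.freq = freq
        # public tf.summary writer instead of TensorBoard's private _writers
        self.writer = tf.summary.create_file_writer(log_dir)

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.freq != 0:
            return
        with tf.GradientTape() as tape:
            y_pred = self.model(self.x)          # forward propagation
            loss = self.loss_fn(self.y, y_pred)  # explicit loss, no compiled_loss
        grads = tape.gradient(loss, self.model.trainable_weights)
        with self.writer.as_default():
            for i, (w, g) in enumerate(zip(self.model.trainable_weights, grads)):
                # index prefix keeps tags unique even if weight names repeat
                tag = f"{i}_{w.name.replace(':', '_')}_grads"
                tf.summary.histogram(tag, data=g, step=epoch)
        self.writer.flush()


# Usage on a hypothetical toy regression problem
log_dir = tempfile.mkdtemp()
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
cb = GradientHistogramCallback(x, y, tf.keras.losses.MeanSquaredError(), log_dir)
model.fit(x, y, epochs=2, verbose=0, callbacks=[cb])

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
```

Pointing tensorboard --logdir at the chosen directory should then show the histograms under the Histograms tab. The trade-off is that the loss function must be passed explicitly instead of being taken from model.compile.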