Create keras callback to save model predictions and targets for each batch during training

NOTE: this answer is outdated and only works with TF1. Check @bers's answer for a solution tested on TF2.


After model compilation, the placeholder tensor for y_true is in model.targets and y_pred is in model.outputs.

To save the values of these placeholders at each batch, you can:

  1. First copy the values of these tensors into variables.
  2. Evaluate these variables in on_batch_end, and store the resulting arrays.

Step 1 is a bit involved, because you have to add a tf.assign op to the training function model.train_function. With the current Keras API, this can be done by passing a fetches argument to K.function() when the training function is constructed.

In model._make_train_function(), there's a line:

self.train_function = K.function(inputs,
                                 [self.total_loss] + self.metrics_tensors,
                                 updates=updates,
                                 name='train_function',
                                 **self._function_kwargs)

The fetches argument containing the tf.assign ops can be provided via model._function_kwargs (only works after Keras 2.1.0).

As an example:

from keras.layers import Dense
from keras.models import Sequential
from keras.callbacks import Callback
from keras import backend as K
import tensorflow as tf
import numpy as np

class CollectOutputAndTarget(Callback):
    def __init__(self):
        super(CollectOutputAndTarget, self).__init__()
        self.targets = []  # collect y_true batches
        self.outputs = []  # collect y_pred batches

        # the shape of these 2 variables will change according to batch shape
        # to handle the "last batch", specify `validate_shape=False`
        self.var_y_true = tf.Variable(0., validate_shape=False)
        self.var_y_pred = tf.Variable(0., validate_shape=False)

    def on_batch_end(self, batch, logs=None):
        # evaluate the variables and save them into lists
        self.targets.append(K.eval(self.var_y_true))
        self.outputs.append(K.eval(self.var_y_pred))

# build a simple model
# have to compile first for model.targets and model.outputs to be prepared
model = Sequential([Dense(5, input_shape=(10,))])
model.compile(loss='mse', optimizer='adam')

# initialize the variables and the `tf.assign` ops
cbk = CollectOutputAndTarget()
fetches = [tf.assign(cbk.var_y_true, model.targets[0], validate_shape=False),
           tf.assign(cbk.var_y_pred, model.outputs[0], validate_shape=False)]
model._function_kwargs = {'fetches': fetches}  # forwarded to `K.function()` when the train function is built

# fit the model and check results
X = np.random.rand(10, 10)
Y = np.random.rand(10, 5)
model.fit(X, Y, batch_size=8, callbacks=[cbk])

If the number of samples is not divisible by the batch size, the final batch will have a different size from the other batches, so K.variable() and K.update() can't be used in this case. You'll have to use tf.Variable(..., validate_shape=False) and tf.assign(..., validate_shape=False) instead.
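
To make the "last batch" issue concrete, here is a minimal sketch of how a range of samples is cut into batches. The helper name batch_bounds is hypothetical; it only mimics the idea behind Keras's internal batching and is not the library's code.

```python
def batch_bounds(num_samples, batch_size):
    """Return (start, end) index pairs covering range(num_samples)."""
    return [(start, min(start + batch_size, num_samples))
            for start in range(0, num_samples, batch_size)]


# Same numbers as in the example above: 10 samples, batch_size=8
bounds = batch_bounds(10, 8)
sizes = [end - start for start, end in bounds]
print(bounds)  # [(0, 8), (8, 10)]
print(sizes)   # [8, 2]
```

With 10 samples and a batch size of 8, the second batch has only 2 rows, which is why a variable with a fixed shape of (8, 5) could not hold it.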


To verify the correctness of the saved arrays, you can add one line in training.py to print out the shuffled index array:

if shuffle == 'batch':
    index_array = _batch_shuffle(index_array, batch_size)
elif shuffle:
    np.random.shuffle(index_array)

print('Index array:', repr(index_array))  # Add this line

batches = _make_batches(num_train_samples, batch_size)

The shuffled index array should be printed out during fitting:

Epoch 1/1
Index array: array([8, 9, 3, 5, 4, 7, 1, 0, 6, 2])
10/10 [==============================] - 0s 23ms/step - loss: 0.5670

And you can check if cbk.targets is the same as Y[index_array]:

index_array = np.array([8, 9, 3, 5, 4, 7, 1, 0, 6, 2])
print(Y[index_array])
[[ 0.75325592  0.64857277  0.1926653   0.7642865   0.38901153]
 [ 0.77567689  0.13573623  0.4902501   0.42897559  0.55825652]
 [ 0.33760938  0.68195038  0.12303088  0.83509441  0.20991668]
 [ 0.98367778  0.61325065  0.28973401  0.28734073  0.93399794]
 [ 0.26097574  0.88219054  0.87951941  0.64887846  0.41996446]
 [ 0.97794604  0.91307569  0.93816428  0.2125808   0.94381495]
 [ 0.74813435  0.08036688  0.38094272  0.83178364  0.16713736]
 [ 0.52609421  0.39218962  0.21022047  0.58569125  0.08012982]
 [ 0.61276627  0.20679494  0.24124858  0.01262245  0.0994412 ]
 [ 0.6026137   0.25620512  0.7398164   0.52558182  0.09955769]]

print(cbk.targets)
[array([[ 0.7532559 ,  0.64857274,  0.19266529,  0.76428652,  0.38901153],
        [ 0.77567691,  0.13573623,  0.49025011,  0.42897558,  0.55825651],
        [ 0.33760938,  0.68195039,  0.12303089,  0.83509439,  0.20991668],
        [ 0.9836778 ,  0.61325067,  0.28973401,  0.28734073,  0.93399793],
        [ 0.26097575,  0.88219053,  0.8795194 ,  0.64887846,  0.41996446],
        [ 0.97794604,  0.91307569,  0.93816429,  0.2125808 ,  0.94381493],
        [ 0.74813437,  0.08036689,  0.38094273,  0.83178365,  0.16713737],
        [ 0.5260942 ,  0.39218962,  0.21022047,  0.58569127,  0.08012982]], dtype=float32),
 array([[ 0.61276627,  0.20679495,  0.24124858,  0.01262245,  0.0994412 ],
        [ 0.60261369,  0.25620511,  0.73981643,  0.52558184,  0.09955769]], dtype=float32)]

As you can see, there are two batches in cbk.targets (one "full batch" of size 8 and the final batch of size 2), and the row order is the same as Y[index_array].
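
Instead of comparing the printed arrays by eye, the check can be automated. The sketch below uses stand-in data rather than a real training run: collected plays the role of cbk.targets after a single epoch, and index_array is assumed to have been copied by hand from the printed output. Keras stores the batches as float32 while Y is float64, so the comparison uses np.allclose rather than exact equality.

```python
import numpy as np

rng = np.random.RandomState(0)
Y = rng.rand(10, 5)
index_array = np.array([8, 9, 3, 5, 4, 7, 1, 0, 6, 2])

# Stand-in for cbk.targets: float32 batches of size 8 and 2
shuffled = Y[index_array].astype(np.float32)
collected = [shuffled[:8], shuffled[8:]]

# Stack the per-batch arrays and compare against the shuffled targets
stacked = np.concatenate(collected, axis=0)
matches = np.allclose(stacked, Y[index_array], atol=1e-6)
print(stacked.shape, matches)
```

In a real run, replace collected with cbk.targets. With more than one epoch, the list holds the batches of all epochs in sequence, so it would have to be split per epoch first.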


From TF 2.2 on, you can use custom training steps rather than callbacks to achieve what you want. Here's a demo that works with tensorflow==2.2.0rc1, using inheritance to extend the keras.Sequential model. Performance-wise, this is not ideal as predictions are made twice, once in self(x, training=True) and once in super().train_step(original_data). But you get the idea.

This works in eager mode and does not use private APIs of the model, so it should be pretty stable (it does import data_adapter from the internal tensorflow.python namespace, though). One caveat is that you have to use tf.keras (standalone keras does not support Model.train_step), but I feel standalone keras is becoming more and more deprecated anyway.

"""Demonstrate access to Keras batch tensors in a tf.keras custom training step."""
import numpy as np
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.python.keras.engine import data_adapter

in_shape = (2,)
out_shape = (1,)
batch_size = 3
n_samples = 7


class SequentialWithPrint(keras.Sequential):
    def train_step(self, original_data):
        # Basically copied one-to-one from https://git.io/JvDTv
        data = data_adapter.expand_1d(original_data)
        x, y_true, w = data_adapter.unpack_x_y_sample_weight(data)
        y_pred = self(x, training=True)

        # this is pretty much like on_train_batch_begin
        K.print_tensor(w, "Sample weight (w) =")
        K.print_tensor(x, "Batch input (x) =")
        K.print_tensor(y_true, "Batch output (y_true) =")
        K.print_tensor(y_pred, "Prediction (y_pred) =")

        result = super().train_step(original_data)

        # add anything here for on_train_batch_end-like behavior

        return result


# Model
model = SequentialWithPrint([keras.layers.Dense(out_shape[0], input_shape=in_shape)])
model.compile(loss="mse", optimizer="adam")

# Example data
X = np.random.rand(n_samples, *in_shape)
Y = np.random.rand(n_samples, *out_shape)

model.fit(X, Y, batch_size=batch_size)
print("X: ", X)
print("Y: ", Y)

Finally, here is a very similar example that does not use inheritance:

"""Demonstrate access to Keras batch tensors in a tf.keras custom training step."""
import numpy as np
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.python.keras.engine import data_adapter

in_shape = (2,)
out_shape = (1,)
batch_size = 3
n_samples = 7


def make_print_data_and_train_step(keras_model):
    original_train_step = keras_model.train_step

    def print_data_and_train_step(original_data):
        # Basically copied one-to-one from https://git.io/JvDTv
        data = data_adapter.expand_1d(original_data)
        x, y_true, w = data_adapter.unpack_x_y_sample_weight(data)
        y_pred = keras_model(x, training=True)

        # this is pretty much like on_train_batch_begin
        K.print_tensor(w, "Sample weight (w) =")
        K.print_tensor(x, "Batch input (x) =")
        K.print_tensor(y_true, "Batch output (y_true) =")
        K.print_tensor(y_pred, "Prediction (y_pred) =")

        result = original_train_step(original_data)

        # add anything here for on_train_batch_end-like behavior

        return result

    return print_data_and_train_step


# Model
model = keras.Sequential([keras.layers.Dense(out_shape[0], input_shape=in_shape)])
model.train_step = make_print_data_and_train_step(model)
model.compile(loss="mse", optimizer="adam")

# Example data
X = np.random.rand(n_samples, *in_shape)
Y = np.random.rand(n_samples, *out_shape)

model.fit(X, Y, batch_size=batch_size)
print("X: ", X)
print("Y: ", Y)
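
The wrapping pattern used in this second example does not depend on Keras. As a rough illustration, the sketch below applies it to a hypothetical ToyModel class (not a Keras class, and make_logging_train_step is a made-up name too): the original bound method is captured before the attribute is overwritten, so the wrapper can still delegate to it.

```python
class ToyModel:
    """Hypothetical stand-in for a model with a train_step method."""

    def train_step(self, data):
        return {"loss": sum(data)}


def make_logging_train_step(model, log):
    # Capture the bound method *before* it is replaced on the instance
    original_train_step = model.train_step

    def logging_train_step(data):
        log.append(list(data))  # on_train_batch_begin-like hook
        return original_train_step(data)

    return logging_train_step


seen = []
model = ToyModel()
model.train_step = make_logging_train_step(model, seen)
result = model.train_step([1, 2, 3])
print(result, seen)  # {'loss': 6} [[1, 2, 3]]
```

Capturing original_train_step inside the factory is what avoids infinite recursion: by the time the wrapper runs, model.train_step already points at the wrapper itself.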

Update: See my other answer for TF>=2.2.

One problem with @Yu-Yang's solution is that it relies on model._function_kwargs, which is not guaranteed to work as it is not part of the API. In particular, in TF2 with eager execution, session kwargs seem to be either not accepted at all or run preemptively due to eager mode.

Therefore, here is my solution tested on tensorflow==2.1.0. The trick is to replace fetches by a Keras metric, in which the assignment operations from fetches are made during training.

This even enables a Keras-only solution if the batch size divides the number of samples; otherwise, another trick has to be applied: initializing the TensorFlow variables with a None shape, similar to validate_shape=False in earlier solutions (compare https://github.com/tensorflow/tensorflow/issues/35667).

Importantly, tf.keras behaves differently from keras (sometimes just ignoring assignments, or seeing variables as Keras symbolic tensors), so this updated solution takes care of both implementations (Keras==2.3.1 and tensorflow==2.1.0).

Update: This solution still works with tensorflow==2.2.0rc1 using Keras==2.3.1. However, I have not been able to get the targets with tf.keras yet since Sequential._targets is not available - the pain of using undocumented APIs. My other answer solves that problem.

"""Demonstrate access to Keras symbolic tensors in a (tf.)keras.Callback."""

import numpy as np
import tensorflow as tf

use_tf_keras = True
if use_tf_keras:
    from tensorflow import keras
    from tensorflow.keras import backend as K

    tf.config.experimental_run_functions_eagerly(False)
    compile_kwargs = {"run_eagerly": False, "experimental_run_tf_function": False}

else:
    import keras
    from keras import backend as K

    compile_kwargs = {}


in_shape = (2,)
out_shape = (1,)
batch_size = 3
n_samples = 7


class CollectKerasSymbolicTensorsCallback(keras.callbacks.Callback):
    """Collect Keras symbolic tensors."""

    def __init__(self):
        """Initialize intermediate variables for batches and lists."""
        super().__init__()

        # Collect batches here
        self.inputs = []
        self.targets = []
        self.outputs = []

        # # For a pure Keras solution, we need to know the shapes beforehand;
        # # in particular, batch_size must divide n_samples:
        # self.input = K.variable(np.empty((batch_size, *in_shape)))
        # self.target = K.variable(np.empty((batch_size, *out_shape)))
        # self.output = K.variable(np.empty((batch_size, *out_shape)))

        # If the shape of these variables will change (e.g., last batch), initialize
        # arbitrarily and specify `shape=tf.TensorShape(None)`:
        self.input = tf.Variable(0.0, shape=tf.TensorShape(None))
        self.target = tf.Variable(0.0, shape=tf.TensorShape(None))
        self.output = tf.Variable(0.0, shape=tf.TensorShape(None))

    def on_batch_end(self, batch, logs=None):
        """Evaluate the variables and save them into lists."""
        self.inputs.append(K.eval(self.input))
        self.targets.append(K.eval(self.target))
        self.outputs.append(K.eval(self.output))

    def on_train_end(self, logs=None):
        """Print all variables."""
        print("Inputs: ", *self.inputs)
        print("Targets: ", *self.targets)
        print("Outputs: ", *self.outputs)


@tf.function
def assign_keras_symbolic_tensors_metric(_foo, _bar):
    """
    Return the assignment operations as a metric to have them evaluated by Keras.

    This replaces `fetches` from the TF1/non-eager-execution solution.
    """
    # Collect assignments as list of (dest, src)
    assignments = (
        (callback.input, model.inputs[0]),
        (callback.target, model._targets[0] if use_tf_keras else model.targets[0]),
        (callback.output, model.outputs[0]),
    )
    for (dest, src) in assignments:
        dest.assign(src)

    return 0


callback = CollectKerasSymbolicTensorsCallback()
metrics = [assign_keras_symbolic_tensors_metric]

# Example model
model = keras.Sequential([keras.layers.Dense(out_shape[0], input_shape=in_shape)])
model.compile(loss="mse", optimizer="adam", metrics=metrics, **compile_kwargs)

# Example data
X = np.random.rand(n_samples, *in_shape)
Y = np.random.rand(n_samples, *out_shape)

model.fit(X, Y, batch_size=batch_size, callbacks=[callback])
print("X: ", X)
print("Y: ", Y)