Save model every 10 epochs tensorflow.keras v2

Explicitly computing the number of batches per epoch worked for me.

import tensorflow as tf

BATCH_SIZE = 20
# Integer division: steps_per_epoch must be a whole number of batches
STEPS_PER_EPOCH = train_labels.size // BATCH_SIZE
SAVE_PERIOD = 10

# Create a callback that saves the model's weights every 10 epochs
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=int(SAVE_PERIOD * STEPS_PER_EPOCH))

# Train the model with the new callback
model.fit(train_images, 
          train_labels,
          batch_size=BATCH_SIZE,
          steps_per_epoch=STEPS_PER_EPOCH,
          epochs=50, 
          callbacks=[cp_callback],
          validation_data=(test_images,test_labels),
          verbose=0)
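
The arithmetic behind save_freq can be checked in plain Python. This is only a minimal sketch: the helper name save_freq_in_batches and the sample count of 60,000 are hypothetical, and it assumes the sample count divides evenly by the batch size.

```python
def save_freq_in_batches(num_samples: int, batch_size: int, save_period_epochs: int) -> int:
    """Return the number of batches between checkpoints for an epoch-based period."""
    if num_samples % batch_size != 0:
        raise ValueError("this sketch assumes num_samples is a multiple of batch_size")
    steps_per_epoch = num_samples // batch_size
    return save_period_epochs * steps_per_epoch


# Hypothetical dataset size; batch size and save period match the snippet above.
print(save_freq_in_batches(num_samples=60_000, batch_size=20, save_period_epochs=10))  # 30000
```

An integer save_freq is counted in batches, so the value has to be recomputed whenever the batch size or the dataset size changes.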

The period param mentioned in the accepted answer is no longer available.

Using the save_freq param is an alternative, but a risky one: an integer save_freq counts batches, so if the dataset size changes, saves no longer line up with epoch boundaries. The docs warn about this as well: "Note that if the saving isn't aligned to epochs, the monitored metric may potentially be less reliable."
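
To illustrate the drift, here is a small simulation in plain Python. It is a simplified model rather than Keras itself: the function name epochs_at_save and the step counts (100 and 120 steps per epoch) are made up for the example, and it assumes batches are counted continuously across epochs.

```python
from typing import List


def epochs_at_save(steps_per_epoch: int, save_freq: int, epochs: int) -> List[float]:
    """Return, for each save, how many epochs have elapsed (fractional if mid-epoch)."""
    total_batches = steps_per_epoch * epochs
    return [batch / steps_per_epoch
            for batch in range(save_freq, total_batches + 1, save_freq)]


# save_freq computed for 100 steps per epoch and a 10-epoch period.
save_freq = 10 * 100

print(epochs_at_save(steps_per_epoch=100, save_freq=save_freq, epochs=50))
# [10.0, 20.0, 30.0, 40.0, 50.0]

# Same save_freq after the dataset grows to 120 steps per epoch:
# most saves now land in the middle of an epoch.
print(epochs_at_save(steps_per_epoch=120, save_freq=save_freq, epochs=50))
```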

Thus, I use a subclass as a solution:

class EpochModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):

    def __init__(self,
                 filepath,
                 frequency=1,
                 monitor='val_loss',
                 verbose=0,
                 save_best_only=False,
                 save_weights_only=False,
                 mode='auto',
                 options=None,
                 **kwargs):
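        # Pass arguments by keyword so they cannot shift if the parent signature changes.
        super().__init__(filepath,
                         monitor=monitor,
                         verbose=verbose,
                         save_best_only=save_best_only,
                         save_weights_only=save_weights_only,
                         mode=mode,
                         save_freq="epoch",
                         options=options,
                         **kwargs)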
        self.epochs_since_last_save = 0
        self.frequency = frequency

    def on_epoch_end(self, epoch, logs=None):
        self.epochs_since_last_save += 1
        # pylint: disable=protected-access
        if self.epochs_since_last_save % self.frequency == 0:
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def on_train_batch_end(self, batch, logs=None):
        pass

Use it as:

callbacks=[
     EpochModelCheckpoint("/your_save_location/epoch{epoch:02d}", frequency=10),
]

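Note that, depending on your TF version, you may have to change the args in the call to the superclass __init__, and possibly in the call to the private _save_model method, whose signature has changed between releases.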
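
The counter logic can be checked without TensorFlow. The sketch below mirrors only the modulo test from on_epoch_end; the function name saved_checkpoint_paths is hypothetical, and it assumes that the {epoch} placeholder is filled with the 1-based epoch number (epoch + 1), which as far as I know is how ModelCheckpoint formats file paths in TF 2.x.

```python
def saved_checkpoint_paths(template: str, frequency: int, epochs: int) -> list:
    """List the paths written when saving every `frequency` epochs."""
    paths = []
    epochs_since_last_save = 0
    for epoch in range(epochs):  # on_epoch_end receives 0-based epoch indices
        epochs_since_last_save += 1
        if epochs_since_last_save % frequency == 0:
            # Assumption: the {epoch} placeholder receives the 1-based epoch number.
            paths.append(template.format(epoch=epoch + 1))
    return paths


print(saved_checkpoint_paths("/your_save_location/epoch{epoch:02d}", frequency=10, epochs=50))
# ['/your_save_location/epoch10', '/your_save_location/epoch20', '/your_save_location/epoch30',
#  '/your_save_location/epoch40', '/your_save_location/epoch50']
```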


With tf.keras.callbacks.ModelCheckpoint, use save_freq='epoch' and pass the extra argument period=10.

Although the official docs do not explain it, that is the way to do it (the docs do mention that you can pass period; they just don't say what it does).