Changing optimizer in Keras during training

You can create an EarlyStopping callback that will stop the training, and extend it with a function that changes your optimizer and calls fit again.

The following callback will monitor the validation loss (val_loss) and stop training after two epochs (patience) without an improvement greater than min_delta.

from tensorflow.keras.callbacks import EarlyStopping

min_delta = 0.000000000001

stopper = EarlyStopping(monitor='val_loss', min_delta=min_delta, patience=2)

But to add an extra action after training is finished, we can extend this callback and override its on_train_end method:

class OptimizerChanger(EarlyStopping):

    def __init__(self, on_train_end, **kwargs):
        self.do_on_train_end = on_train_end
        super(OptimizerChanger, self).__init__(**kwargs)

    def on_train_end(self, logs=None):
        super(OptimizerChanger, self).on_train_end(logs)
        self.do_on_train_end()

And the custom function to call when the model finishes training:

def do_after_training():
    # Warning: this creates a fresh optimizer with no accumulated state,
    # so at first it might give you worse training performance than before
    model.compile(optimizer='SGD', loss=..., metrics=...)
    model.fit(...)

Now let's use the callback:

changer = OptimizerChanger(on_train_end= do_after_training, 
                           monitor='val_loss',
                           min_delta=min_delta,
                           patience=2)

model.fit(..., ..., callbacks = [changer])
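To show the pieces fitting together, here is a minimal end-to-end sketch. The toy model and random data are assumptions standing in for the `model.fit(...)` placeholders above; the pattern (early-stop on Adam, then re-compile with SGD and keep training from inside `on_train_end`) is the one described:

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping

class OptimizerChanger(EarlyStopping):
    """EarlyStopping that runs one extra callable once training ends."""
    def __init__(self, on_train_end, **kwargs):
        self.do_on_train_end = on_train_end
        super(OptimizerChanger, self).__init__(**kwargs)

    def on_train_end(self, logs=None):
        super(OptimizerChanger, self).on_train_end(logs)
        self.do_on_train_end()

# Toy data and model (assumed, not from the original post).
X = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                             tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

def do_after_training():
    # Re-compiling swaps the optimizer; the layer weights are kept.
    model.compile(optimizer="sgd", loss="mse")
    model.fit(X, y, epochs=2, verbose=0)

changer = OptimizerChanger(on_train_end=do_after_training,
                           monitor="loss",
                           min_delta=1e-12,
                           patience=2)

model.fit(X, y, epochs=5, verbose=0, callbacks=[changer])
```

Whether or not early stopping actually triggers, `on_train_end` fires when the first fit finishes, so the model always ends up training its last epochs with SGD.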

I did this and it worked:

class myCallback(tf.keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        self.model.optimizer = new_model_optimizer
        self.model.loss = new_model_loss
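One caveat with assigning `self.model.optimizer` directly: once Keras has traced a graph-mode train step, the old optimizer's update ops can stay baked into that traced function, so the swap may silently not take effect. A hedged, runnable sketch of this approach (toy model and `OptimizerSwap` name are assumptions) compiles with `run_eagerly=True` so the optimizer attribute is looked up on every step:

```python
import numpy as np
import tensorflow as tf

class OptimizerSwap(tf.keras.callbacks.Callback):
    """Replace the model's optimizer at the end of a chosen epoch."""
    def __init__(self, swap_after_epoch, new_optimizer):
        super().__init__()
        self.swap_after_epoch = swap_after_epoch
        self.new_optimizer = new_optimizer

    def on_epoch_end(self, epoch, logs=None):
        if epoch == self.swap_after_epoch:
            self.model.optimizer = self.new_optimizer

X = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                             tf.keras.layers.Dense(1)])
# run_eagerly=True so the swapped-in optimizer is actually used each step.
model.compile(optimizer="adam", loss="mse", run_eagerly=True)

model.fit(X, y, epochs=4, verbose=0,
          callbacks=[OptimizerSwap(swap_after_epoch=1,
                                   new_optimizer=tf.keras.optimizers.SGD())])
```

Assigning `self.model.loss` this way is even less reliable, since the compiled loss is what the training loop actually uses; re-compiling is the safe route for changing the loss.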

Would something like this work?

model.compile(optimizer='Adam', ...)
model.fit(X, y, epochs=100, callbacks=[EarlyStoppingCallback])
# now switch to SGD and finish training
model.compile(optimizer='SGD', ...)
model.fit(X, y, epochs=10)

Or would the second call to compile overwrite all the variables (i.e. do something like the old tf.initialize_all_variables())?
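This is easy to check directly: in Keras, `compile()` configures the training setup but does not reinitialize the layer weights. A quick sketch (toy model assumed for illustration):

```python
import numpy as np
import tensorflow as tf

X = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)),
                             tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
model.fit(X, y, epochs=2, verbose=0)

weights_before = [w.copy() for w in model.get_weights()]

# Switch to SGD: the trained weights carry over unchanged.
model.compile(optimizer="sgd", loss="mse")
weights_after = model.get_weights()

same = all(np.array_equal(a, b)
           for a, b in zip(weights_before, weights_after))
print(same)  # True: compile() kept the trained weights
```

What you do lose on re-compile is the optimizer's internal state (e.g. Adam's moment estimates), which is why training can briefly get worse right after the switch.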

(This is actually a follow-up question, but I'm writing it as an answer because Stack Overflow does not allow code in comments.)

Tags:

Keras