Save Keras ModelCheckpoints in Google Cloud Bucket

I faced a similar problem and the solution above didn't work for me. The file must be read and written in binary form, otherwise this error will be thrown:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x89 in position 0: invalid start byte

So the code becomes:

import os
from tensorflow.python.lib.io import file_io

def copy_file_to_gcs(job_dir, file_path):
    # Read the local file in binary mode and write it to the bucket in binary mode
    with file_io.FileIO(file_path, mode='rb') as input_f:
        with file_io.FileIO(os.path.join(job_dir, file_path), mode='wb+') as output_f:
            output_f.write(input_f.read())
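
A hypothetical usage, with the bucket path as a placeholder for your own --job-dir:

# Save locally first, then copy the file into the bucket
model.save('model.h5')
copy_file_to_gcs('gs://your_bucket/models', 'model.h5')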


The issue can be solved with the following piece of code:

# Save Keras ModelCheckpoints locally
model.save('model.h5')

# Copy model.h5 over to Google Cloud Storage
with file_io.FileIO('model.h5', mode='r') as input_f:
    with file_io.FileIO('gs://your_bucket/model.h5', mode='w+') as output_f:
        output_f.write(input_f.read())
        print("Saved model.h5 to GCS")

The model.h5 is saved on the local filesystem and then copied over to GCS. As Jochen pointed out, there is currently no easy support for writing HDF5 model checkpoints directly to GCS. With this hack it is possible to write the data until an easier solution is provided.


Here is the code I wrote to save the model after each epoch.

import os
import warnings
import numpy as np
from keras.callbacks import ModelCheckpoint
from tensorflow.python.lib.io import file_io

class ModelCheckpointGC(ModelCheckpoint):
    """Taken from and modified:
    https://github.com/keras-team/keras/blob/tf-keras/keras/callbacks.py
    """

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        self.epochs_since_last_save += 1
        if self.epochs_since_last_save >= self.period:
            self.epochs_since_last_save = 0
            filepath = self.filepath.format(epoch=epoch, **logs)
            if self.save_best_only:
                current = logs.get(self.monitor)
                if current is None:
                    warnings.warn('Can save best model only with %s available, '
                                  'skipping.' % (self.monitor), RuntimeWarning)
                else:
                    if self.monitor_op(current, self.best):
                        if self.verbose > 0:
                            print('Epoch %05d: %s improved from %0.5f to %0.5f,'
                                  ' saving model to %s'
                                  % (epoch, self.monitor, self.best,
                                     current, filepath))
                        self.best = current
                        if self.save_weights_only:
                            self.model.save_weights(filepath, overwrite=True)
                        else:
                            if is_development():
                                self.model.save(filepath, overwrite=True)
                            else:
                                # On gcloud: save locally first, then copy the
                                # file to the GCS path in binary mode
                                self.model.save(filepath.split("/")[-1])
                                with file_io.FileIO(filepath.split("/")[-1], mode='rb') as input_f:
                                    with file_io.FileIO(filepath, mode='wb+') as output_f:
                                        output_f.write(input_f.read())
                    else:
                        if self.verbose > 0:
                            print('Epoch %05d: %s did not improve' %
                                  (epoch, self.monitor))
            else:
                if self.verbose > 0:
                    print('Epoch %05d: saving model to %s' % (epoch, filepath))
                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    if is_development():
                        self.model.save(filepath, overwrite=True)
                    else:
                        # On gcloud: save locally first, then copy to GCS
                        self.model.save(filepath.split("/")[-1])
                        with file_io.FileIO(filepath.split("/")[-1], mode='rb') as input_f:
                            with file_io.FileIO(filepath, mode='wb+') as output_f:
                                output_f.write(input_f.read())

There is a function is_development() that checks whether the code is running in the local or the gcloud environment. In the local environment I set the variable LOCAL_ENV=1:

def is_development():
    """Check if the environment is local or in the gcloud.
    Requires the variable to be exported in the bash profile:
    export LOCAL_ENV=1

    Returns:
        [boolean] -- True if local env
    """
    return os.environ.get('LOCAL_ENV') == '1'

Then you can use it:

checkpoint = ModelCheckpointGC(
    'gs://your_bucket/models/model.h5',
    monitor='loss',
    verbose=1,
    save_best_only=True,
    mode='min')
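
The callback can then be passed to model.fit like any other Keras callback (the training data and epoch count here are just placeholders):

model.fit(x_train, y_train,
          epochs=10,
          callbacks=[checkpoint])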

I hope that helps someone and saves some time.


I might be a bit late on this, but for the sake of future visitors I will describe the whole process of how to adapt code that was previously run locally to be GoogleML-aware from the IO point of view.

  1. Python standard open(file_name, mode) does not work with buckets (gs://...../file_name). One needs to import file_io with from tensorflow.python.lib.io import file_io and change all calls from open(file_name, mode) to file_io.FileIO(file_name, mode=mode) (note the named mode parameter). The interface of the opened handle is the same; see the sketch after this list.
  2. Keras and/or other libraries mostly use the standard open(file_name, mode) internally, so calls like trained_model.save(file_path) that go through 3rd-party libraries will fail to store the result to the bucket. The only way to retrieve a model after the job has finished successfully is to store it locally and then move it to the bucket.
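
For point 1, a minimal sketch of that substitution, with the bucket path as a placeholder:

from tensorflow.python.lib.io import file_io

# Instead of: open('gs://your_bucket/some_file.txt', 'w')
with file_io.FileIO('gs://your_bucket/some_file.txt', mode='w') as f:
    f.write('hello')

# Instead of: open('gs://your_bucket/some_file.txt', 'r')
with file_io.FileIO('gs://your_bucket/some_file.txt', mode='r') as f:
    print(f.read())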

The code below is quite inefficient, because it loads the whole model at once and then dumps it to the bucket, but it worked for me for relatively small models:

model.save(file_path)

# Copy the locally saved model over to the bucket (binary mode on both sides)
with file_io.FileIO(file_path, mode='rb') as input_f:
    with file_io.FileIO(os.path.join(model_dir, file_path), mode='wb+') as output_f:
        output_f.write(input_f.read())

The mode must be set to binary for both reading and writing.

When the file is relatively big, it makes sense to read and write it in chunks to decrease memory consumption.
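
A rough sketch of such a chunked copy, assuming a hypothetical 1 MiB chunk size:

def copy_in_chunks(src_path, dst_path, chunk_size=1024 * 1024):
    # Stream the file in fixed-size chunks instead of loading it all into memory
    with file_io.FileIO(src_path, mode='rb') as input_f:
        with file_io.FileIO(dst_path, mode='wb+') as output_f:
            while True:
                chunk = input_f.read(chunk_size)
                if not chunk:
                    break
                output_f.write(chunk)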

  3. Before running a real task, I would advise running a stub that simply saves a file to the remote bucket.

This implementation, temporarily put in place of the real train_model call, should do:

import argparse
import os

from tensorflow.python.lib.io import file_io

if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '--job-dir',
        help='GCS location with read/write access',
        required=True
    )

    args = parser.parse_args()
    arguments = args.__dict__
    job_dir = arguments.pop('job_dir')

    # Write a small test file directly into the job directory on GCS
    with file_io.FileIO(os.path.join(job_dir, "test.txt"), mode='wb+') as output_f:
        output_f.write("Test passed.")

After a successful execution you should see the file test.txt with the content "Test passed." in your bucket.
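
To verify from Python without downloading anything manually, file_io can read the file straight from the bucket (the path below is a placeholder for your actual job_dir):

from tensorflow.python.lib.io import file_io

print(file_io.read_file_to_string('gs://your_bucket/path/to/job_dir/test.txt'))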