How to Properly Combine TensorFlow's Dataset API and Keras?

Update June 09, 2018

  • Starting from TensorFlow 1.9, one can pass a tf.data.Dataset object directly into keras.Model.fit(), and it behaves similarly to fit_generator.
  • A complete example can be found in this gist.
# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train, is_training=True)

model = ...  # your Keras model here
model.fit(
    training_set.make_one_shot_iterator(),
    steps_per_epoch=len(x_train) // 128,  # 128 matches the default batch_size
    epochs=5,
    verbose=1)
  • tfdata_generator is a function that returns an iterable tf.data.Dataset.
def tfdata_generator(images, labels, is_training, batch_size=128):
  '''Construct a data generator using `tf.data.Dataset`. '''

  def map_fn(image, label):
    '''Preprocess raw data to trainable input. '''
    # _NUM_CLASSES is assumed to be defined elsewhere (10 for MNIST)
    x = tf.reshape(tf.cast(image, tf.float32), (28, 28, 1))
    y = tf.one_hot(tf.cast(label, tf.uint8), _NUM_CLASSES)
    return x, y

  dataset = tf.data.Dataset.from_tensor_slices((images, labels))

  if is_training:
    dataset = dataset.shuffle(1000)  # depends on sample size
  dataset = dataset.map(map_fn)
  dataset = dataset.batch(batch_size)
  dataset = dataset.repeat()
  dataset = dataset.prefetch(tf.contrib.data.AUTOTUNE)

  return dataset
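
The `# depends on sample size` comment on `dataset.shuffle(1000)` matters because tf.data shuffles through a fixed-size buffer rather than over the whole dataset. The following is a minimal pure-Python sketch of that buffering idea, not TensorFlow's implementation; `buffered_shuffle` is a hypothetical helper written only for illustration.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in the order produced by a fixed-size shuffle buffer.

    Hypothetical helper that imitates the idea behind dataset.shuffle(buffer_size).
    """
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    for item in buffer:
        yield item


# With a buffer of 1 the order is unchanged, so nothing is shuffled.
unshuffled = list(buffered_shuffle(range(10), buffer_size=1))

# With a buffer of 10 over 100 items, the shuffle is only local:
# the k-th output always comes from the first k + 10 inputs.
local = list(buffered_shuffle(range(100), buffer_size=10))
```

Under this model, a buffer much smaller than the dataset only mixes nearby samples, so a buffer size at least as large as the number of samples is needed for a full shuffle.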

Old Solution:

In addition to @Yu-Yang's answer, you can also wrap a tf.data.Dataset in a generator for fit_generator, as follows:

from tensorflow.contrib.learn.python.learn.datasets import mnist

data  = mnist.load_mnist()
model = ...  # your Keras model
# tfdata_generator is defined below; define it before running this call
model.fit_generator(generator=tfdata_generator(data.train.images, data.train.labels),
                    steps_per_epoch=200,
                    workers=0,  # This is important
                    verbose=1)


def tfdata_generator(images, labels, batch_size=128, shuffle=True):
    def map_func(image, label):
        '''A transformation function'''
        # image_shape and num_classes are assumed to be defined elsewhere
        x_train = tf.reshape(tf.cast(image, tf.float32), image_shape)
        y_train = tf.one_hot(tf.cast(label, tf.uint8), num_classes)
        return x_train, y_train

    dataset  = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset  = dataset.map(map_func)
    if shuffle:
        # shuffle() requires a buffer size; 1000 is an arbitrary choice
        dataset = dataset.shuffle(1000)
    dataset  = dataset.batch(batch_size).repeat()
    iterator = dataset.make_one_shot_iterator()

    next_batch = iterator.get_next()
    while True:
        yield K.get_session().run(next_batch)
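
Because the generator above never terminates, `steps_per_epoch` is what tells the fit loop where one epoch ends. A minimal pure-Python sketch of that interaction follows; `batch_generator` and `run_epochs` are hypothetical stand-ins for the tf.data pipeline and the Keras training loop, not their actual implementations.

```python
import itertools


def batch_generator(samples, batch_size):
    """Yield batches forever, restarting from the beginning after each pass."""
    while True:  # plays the role of dataset.repeat()
        for start in range(0, len(samples), batch_size):
            # plays the role of dataset.batch(batch_size)
            yield samples[start:start + batch_size]


def run_epochs(generator, steps_per_epoch, epochs):
    """Pull exactly steps_per_epoch batches per epoch from an endless generator."""
    history = []
    for _ in range(epochs):
        history.append(list(itertools.islice(generator, steps_per_epoch)))
    return history


samples = list(range(10))
history = run_epochs(batch_generator(samples, batch_size=4),
                     steps_per_epoch=3, epochs=2)
```

Here 10 samples with a batch size of 4 give 3 batches per pass (the last one partial), so `steps_per_epoch=3` makes each epoch cover the data exactly once. A smaller value would end the epoch early, and the next epoch would simply continue where the generator left off.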

There is indeed a more efficient way to use Dataset without having to convert the tensors into numpy arrays. However, it is not (yet?) in the official documentation. According to the release notes, the feature was introduced in Keras 2.0.7, so you may need to install keras>=2.0.7 to use it.

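import numpy as np
import tensorflow as tf
from keras.layers import Dense, Input
from keras.models import Model

Dataset = tf.data.Dataset

x = np.arange(4).reshape(-1, 1).astype('float32')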
ds_x = Dataset.from_tensor_slices(x).repeat().batch(4)
it_x = ds_x.make_one_shot_iterator()

y = np.arange(5, 9).reshape(-1, 1).astype('float32')
ds_y = Dataset.from_tensor_slices(y).repeat().batch(4)
it_y = ds_y.make_one_shot_iterator()

input_vals = Input(tensor=it_x.get_next())
output = Dense(1, activation='relu')(input_vals)
model = Model(inputs=input_vals, outputs=output)
model.compile('rmsprop', 'mse', target_tensors=[it_y.get_next()])
model.fit(steps_per_epoch=1, epochs=5, verbose=2)

Several differences:

  1. Supply the tensor argument to the Input layer. Keras will read values from this tensor and use them as the input to fit the model.
  2. Supply the target_tensors argument to Model.compile().
  3. Remember to convert both x and y into float32. Under normal usage, Keras will do this conversion for you. But now you'll have to do it yourself.
  4. Batch size is specified during the construction of Dataset. Use steps_per_epoch and epochs to control when to stop model fitting.

In short, use Input(tensor=...), model.compile(target_tensors=...) and model.fit(x=None, y=None, ...) if your data are to be read from tensors.


The other answers are good; however, it is important to note that using from_tensor_slices directly with large numpy arrays can quickly fill up your memory because, IIRC, the values are copied into the graph as tf.constants. In my experience, this causes a silent failure: training eventually starts, but the loss and other metrics never improve.
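
To get a feel for when this becomes a problem, here is a rough back-of-the-envelope sketch in plain Python. It assumes the arrays really are embedded in the graph and uses the roughly 2 GB size limit of a serialized graph as the threshold; `array_nbytes` and the second dataset shape are hypothetical and chosen only for illustration.

```python
def array_nbytes(shape, bytes_per_element):
    """Return the size in bytes of a dense array with the given shape."""
    count = 1
    for dim in shape:
        count *= dim
    return count * bytes_per_element


# Approximate threshold: a serialized graph is limited to about 2 GB.
GRAPH_LIMIT_BYTES = 2 * 1024 ** 3

FLOAT32_BYTES = 4

# MNIST training images as float32: 60000 x 28 x 28 x 1.
mnist_bytes = array_nbytes((60000, 28, 28, 1), FLOAT32_BYTES)

# A hypothetical larger image dataset: 50000 x 224 x 224 x 3.
large_bytes = array_nbytes((50000, 224, 224, 3), FLOAT32_BYTES)

for name, size in [("mnist", mnist_bytes), ("large", large_bytes)]:
    fits = size < GRAPH_LIMIT_BYTES
    print("%s: %.1f MiB, fits under limit: %s" % (name, size / 1024 ** 2, fits))
```

MNIST stays well under the threshold, while the larger dataset exceeds it many times over, which is the case where feeding the arrays some other way becomes necessary.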

A better way is to use placeholders. For example, here is my code to create a generator for images and their one-hot targets:

def create_generator_tf_dataset(self, images, onehots, batch_size, map_fn=None):
    # Get shapes
    img_size = images.shape
    img_size = (None, img_size[1], img_size[2], img_size[3])
    onehot_size = onehots.shape
    onehot_size = (None, onehot_size[1])

    # Placeholders
    images_tensor = tf.placeholder(tf.float32, shape=img_size)
    onehots_tensor = tf.placeholder(tf.float32, shape=onehot_size)

    # Dataset
    dataset = tf.data.Dataset.from_tensor_slices((images_tensor, onehots_tensor))
    # Map function (e.g. augmentation)
    if map_fn is not None:
        dataset = dataset.map(lambda x, y: (map_fn(x), y), num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # Combined shuffle and infinite repeat
    dataset = dataset.apply(
        tf.data.experimental.shuffle_and_repeat(len(images), None))  
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(1)

    # Make the iterator
    iterator = dataset.make_initializable_iterator()
    init_op = iterator.initializer
    next_val = iterator.get_next()

    with K.get_session().as_default() as sess:
        sess.run(init_op, feed_dict={images_tensor: images, onehots_tensor: onehots})
        while True:
            inputs, labels = sess.run(next_val)
            yield inputs, labels