Keras model.fit() with the tf.data.Dataset API + validation_data

The way to connect a reinitializable iterator to a Keras model is to plug in an iterator that returns the x and y values together, as an (x, y) pair:

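# TensorFlow 1.x API (tf.Session), using the Keras bundled with TensorFlow
import numpy as np
import tensorflow as tf
from tensorflow import keras

sess = tf.Session()
keras.backend.set_session(sess)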

x = np.random.random((5, 2))
y = np.array([0, 1] * 3 + [1, 0] * 2).reshape(5, 2) # One hot encoded
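# Keras expects the iterator to yield batches; repeat() lets it span several epochs
input_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(5).repeat()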

# Create your reinitializable_iterator and initializer
reinitializable_iterator = tf.data.Iterator.from_structure(input_dataset.output_types, input_dataset.output_shapes)
init_op = reinitializable_iterator.make_initializer(input_dataset)

# Run the initializer
sess.run(init_op)  # pass a feed_dict here if your input comes from placeholders

# Build the Keras model and plug in the iterator
model = keras.Model(...)
model.compile(...)
model.fit(reinitializable_iterator, ...)

If you also have a validation dataset, the easiest approach is to create a separate iterator and pass it as the validation_data argument. Be sure to set steps_per_epoch and validation_steps explicitly, since they cannot be inferred from an iterator.
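
As a rough sketch of how those two values might be chosen: if you know how many samples each dataset holds and what batch size it uses, the step count is the number of batches needed to see every sample once. The helper name steps_for and the sample counts below are made up for illustration; they are not from the original answer.

```python
import math


def steps_for(num_samples, batch_size):
    """Batches needed to cover every sample once (the last batch may be partial)."""
    return math.ceil(num_samples / batch_size)


# Hypothetical dataset sizes, for illustration only.
steps_per_epoch = steps_for(num_samples=1000, batch_size=32)
validation_steps = steps_for(num_samples=250, batch_size=32)
```

These would then be passed along with the two iterators, e.g. model.fit(train_iterator, steps_per_epoch=steps_per_epoch, validation_data=val_iterator, validation_steps=validation_steps), where train_iterator and val_iterator stand for your own iterators. If your dataset drops the final partial batch, round down instead of up.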


I solved the problem by using fit_generator, following @Dat-Nguyen's solution, which I found here.

You simply need to create two iterators, one for training and one for validation, and then write your own generator that pulls batches from the dataset and yields them in the form (batch_data, batch_labels). Finally, pass train_generator and validation_generator to model.fit_generator.
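
The generator described above could look something like the sketch below. It is not the code from the linked answer. The names batch_generator, get_next, reset and ListSource are all hypothetical, and ListSource is a plain-Python stand-in for a dataset iterator so the example runs without TensorFlow. With a TF 1.x iterator you would presumably pass get_next=lambda: sess.run(next_element), reset=lambda: sess.run(init_op) and end_of_data=tf.errors.OutOfRangeError.

```python
def batch_generator(get_next, reset, end_of_data=StopIteration):
    """Yield (batch_data, batch_labels) forever, restarting when the data runs out.

    get_next: callable returning one (batch_data, batch_labels) pair
    reset: callable that rewinds the underlying iterator
    end_of_data: exception type that get_next raises when exhausted
    """
    while True:
        try:
            yield get_next()
        except end_of_data:
            reset()  # rewind; the next loop pass fetches the first batch again


class ListSource:
    """Hypothetical stand-in for a dataset iterator over in-memory lists."""

    def __init__(self, data, labels, batch_size):
        self.data, self.labels, self.batch_size = data, labels, batch_size
        self.pos = 0

    def reset(self):
        self.pos = 0

    def get_next(self):
        if self.pos >= len(self.data):
            raise StopIteration
        window = slice(self.pos, self.pos + self.batch_size)
        self.pos += self.batch_size
        return self.data[window], self.labels[window]


# Toy data, for illustration only.
source = ListSource([1, 2, 3, 4, 5], [0, 1, 0, 1, 0], batch_size=2)
train_generator = batch_generator(source.get_next, source.reset)
```

A second generator built the same way from the validation iterator would serve as validation_generator. Because the generator never terminates on its own, fit_generator relies on steps_per_epoch and validation_steps to know where each epoch ends. Note that this sketch loops forever if the source is empty.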