How can I merge multiple tfrecords file into one file?

Addressing the question title directly for anyone looking to merge multiple .tfrecord files:

The most convenient approach is to use the tf.data API (adapting an example from the docs):

# Create dataset from multiple .tfrecord files
list_of_tfrecord_files = [file1, file2, file3, file4]
dataset = tf.data.TFRecordDataset(list_of_tfrecord_files)

# Save dataset to .tfrecord file
filename = 'test.tfrecord'
writer = tf.data.experimental.TFRecordWriter(filename)
writer.write(dataset)
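If `tf.data.experimental.TFRecordWriter` is unavailable in your version, there is also a lower-level route: uncompressed TFRecord files are a headerless stream of length-prefixed records, so concatenating the files byte for byte should likewise yield a valid merged file. A minimal sketch in plain Python (the function name and paths are placeholders):

```python
def concat_tfrecords(input_paths, output_path, chunk_size=1 << 20):
    """Merge uncompressed .tfrecord files by byte concatenation.

    The TFRecord format has no file header or footer -- it is just a
    sequence of (length, crc, data, crc) records -- so appending the
    files back to back produces a valid merged file.
    """
    with open(output_path, "wb") as out:
        for path in input_paths:
            with open(path, "rb") as f:
                while True:
                    chunk = f.read(chunk_size)
                    if not chunk:
                        break
                    out.write(chunk)
```

Note this only holds for uncompressed .tfrecord files; if you wrote them with GZIP or ZLIB compression, read and rewrite the records through TensorFlow instead.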

However, as pointed out by holmescn, you'd likely be better off leaving the .tfrecord files as separate files and reading them together as a single tensorflow dataset.

You may also refer to a longer discussion regarding multiple .tfrecord files on Data Science Stack Exchange.


Since the question was asked two months ago, I assume you have already found a solution. For anyone who comes later: the answer is NO, you do not need to create a single HUGE tfrecord file. Just use the new Dataset API:

dataset = tf.data.TFRecordDataset(filenames_to_read,
    compression_type=None,    # or 'GZIP', 'ZLIB' if your data is compressed
    buffer_size=10240,        # any buffer size you want, or 0 for no buffering
    num_parallel_reads=os.cpu_count()  # or 0 to read sequentially
)

# Maybe you want to prefetch some data first.
dataset = dataset.prefetch(buffer_size=batch_size)

# Decode the example
dataset = dataset.map(single_example_parser, num_parallel_calls=os.cpu_count())

dataset = dataset.shuffle(buffer_size=number_larger_than_batch_size)
dataset = dataset.batch(batch_size).repeat(num_epochs)
...

For details, check the documentation.
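Note that `single_example_parser` is not defined in the snippet above. A typical parser built on `tf.io.parse_single_example` looks like the sketch below; the feature names and types here are assumptions and must match the schema you used when writing the records:

```python
import tensorflow as tf

def single_example_parser(serialized_example):
    # Feature spec is an assumption -- replace with your own schema.
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.int64),
    }
    parsed = tf.io.parse_single_example(serialized_example, features)
    return parsed["image"], parsed["label"]
```

Pass this function to `dataset.map(...)` as shown above so each serialized record is decoded into tensors.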


The answer by MoltenMuffins works for higher versions of TensorFlow. However, if you are using a lower version, you have to iterate through the tfrecords and write them into a new record file as follows. This works for tf versions 1.0 and above.

def comb_tfrecord(tfrecords_path, save_path, batch_size=128):
    with tf.Graph().as_default(), tf.Session() as sess:
        ds = tf.data.TFRecordDataset(tfrecords_path).batch(batch_size)
        batch = ds.make_one_shot_iterator().get_next()
        writer = tf.python_io.TFRecordWriter(save_path)
        while True:
            try:
                records = sess.run(batch)
                for record in records:
                    writer.write(record)
            except tf.errors.OutOfRangeError:
                break
        writer.close()
