Interleaving tf.data.Datasets

MattScarpino is on the right track in his comment. You can use Dataset.zip() along with Dataset.flat_map() to flatten a multi-element dataset:

ds0 = tf.data.Dataset.range(0, 10, 2)
ds1 = tf.data.Dataset.range(1, 10, 2)

# Zip combines an element from each input into a single element, and flat_map
# enables you to map the combined element into two elements, then flattens the
# result.
dataset = tf.data.Dataset.zip((ds0, ds1)).flat_map(
    lambda x0, x1: tf.data.Dataset.from_tensors(x0).concatenate(
        tf.data.Dataset.from_tensors(x1)))

iterator = dataset.make_one_shot_iterator()
val = iterator.get_next()
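
If you are on TF 2.x with eager execution, the one-shot iterator is unnecessary; a minimal sketch that checks the output order directly:

# In eager mode the dataset itself is iterable.
for val in dataset:
  print(val.numpy())  # prints 0, 1, 2, ..., 9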

Having said this, your intuition about using Dataset.interleave() is pretty sensible. We're investigating ways that you can do this more easily.


PS. As an alternative, you can use Dataset.interleave() to solve the problem if you change how ds0 and ds1 are defined:

dataset = tf.data.Dataset.range(2).interleave(
    lambda x: tf.data.Dataset.range(x, 10, 2), cycle_length=2, block_length=1)
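
Here cycle_length=2 keeps both generated datasets open at once, and block_length=1 takes one element from each before cycling, which yields 0, 1, 2, ..., 9 in order. A quick sanity check (a sketch, assuming TF 2.x eager execution):

print(list(dataset.as_numpy_iterator()))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]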

Pavel's answer (the sample_from_datasets approach below) works great if you don't mind the order of interleaving. If you do care about the order, here are some options.

Option 1

A variant on mrry's answer that works with an arbitrary number of input datasets:

ds0 = tf.data.Dataset.range(0, 10, 3)
ds1 = tf.data.Dataset.range(1, 10, 3)
ds2 = tf.data.Dataset.range(2, 10, 3)
datasets = (ds0, ds1, ds2)

# Note: `datasets` should be a *tuple*, not a list.
datasets_zipped = tf.data.Dataset.zip(datasets)
# Each element of the dataset will now be a tuple, e.g. (0, 1, 2).
datasets_zipped_tensor = datasets_zipped.map(lambda *args: tf.stack(args))
# Each element will now be a Tensor, e.g. Tensor([0, 1, 2]).
datasets_interleaved = datasets_zipped_tensor.unbatch()
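
Putting it together (a sketch, assuming TF 2.x so that unbatch() and as_numpy_iterator() are available):

print(list(datasets_interleaved.as_numpy_iterator()))
# [0, 1, 2, 3, 4, 5, 6, 7, 8] -- ds0's trailing 9 is dropped; see below.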

However, note that because of the way zip works, the dataset this produces is limited by the length of the shortest input dataset. For example, using the above code with

datasets = (
    tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5]),
    tf.data.Dataset.from_tensor_slices([10, 20]),
)

would yield a dataset comprising just [1, 10, 2, 20].
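
A quick check of that claim, reusing the zip/stack/unbatch recipe from above:

zipped = tf.data.Dataset.zip(datasets).map(lambda *args: tf.stack(args))
print(list(zipped.unbatch().as_numpy_iterator()))  # [1, 10, 2, 20]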

Option 2

Dataset.interleave doesn't suffer from this problem. In some cases you can use interleave with:

# Note: `datasets` should be a *list*, not a tuple
tf.data.Dataset.from_tensor_slices(datasets).interleave(lambda x: x)

But this doesn't seem to work for all kinds of dataset: calling from_tensor_slices on a list of arbitrary dataset objects may fail.

Option 3

If option 2 doesn't work, you might be able to use interleave at an earlier stage of your dataset pipeline. For example, you might be able to change from calling interleave on pre-existing datasets to calling interleave on the file names from which the individual datasets were created:

filenames = ['foo', 'bar']
filenames_dataset = tf.data.Dataset.from_tensor_slices(filenames)

def read_dataset(filename): ...

interleaved_dataset = filenames_dataset.interleave(read_dataset)

But this will only work if your read_dataset function accepts a Tensor argument.
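
For instance, if each file were a plain-text file with one record per line, a hypothetical read_dataset could look like this (a sketch; tf.data.TextLineDataset accepts the scalar string Tensor that interleave passes in):

def read_dataset(filename):
  # `filename` arrives as a scalar tf.string Tensor.
  return tf.data.TextLineDataset(filename)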

Option 4

If none of the other options work for you, I think the only solution is to implement the interleaving yourself, with something like:

element_spec = datasets[0].element_spec
assert all(dataset.element_spec == element_spec for dataset in datasets)

def interleave_generator():
  iters_not_exhausted = [iter(dataset) for dataset in datasets]
  while iters_not_exhausted:
    # Iterate over a copy: exhausted iterators are removed from the list
    # below, and mutating a list while looping over it skips elements.
    for dataset_iter in list(iters_not_exhausted):
      try:
        x = next(dataset_iter)
      except StopIteration:
        iters_not_exhausted.remove(dataset_iter)
      else:
        yield x

datasets_interleaved = tf.data.Dataset.from_generator(
    interleave_generator,
    output_signature=element_spec,
)
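
A usage sketch with the three range datasets from Option 1 (assuming TF 2.x eager execution, which iter(dataset) and output_signature require):

print(list(datasets_interleaved.as_numpy_iterator()))
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -- the trailing 9 is kept this time.

Note that because every element passes through a Python generator, this approach won't benefit from tf.data's native parallelism.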

The tf.data.experimental.sample_from_datasets method can also be useful if you do not need to preserve a strict order for the items you want to interleave.

In my case I had to interleave real-life data with some synthetic data, so the order was not an issue for me. This can then be done easily as follows:

dataset = tf.data.experimental.sample_from_datasets([ds0, ds1])

Note that the result will be non-deterministic and some items may be drawn from the same dataset twice in a row, but in general it will be very similar to a regular interleave.

The advantages of this approach:

  • you can use multiple datasets in one method call
  • you can specify the fraction of samples drawn from each dataset using the weights parameter (e.g. I wanted only a small fraction of the data to be generated, so I used weights=[0.9, 0.1]); see the sketch below
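
A sketch of the weighted call (the dataset names here are illustrative):

# Roughly 90% of elements come from the real data, 10% from the synthetic.
dataset = tf.data.experimental.sample_from_datasets(
    [real_dataset, synthetic_dataset], weights=[0.9, 0.1])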
