TensorFlow tf.data.Dataset API: is there a dataset unzip function?

My workaround was to just use map. I'm not sure whether there will be interest in a syntactic-sugar function for unzip later, though.

a = dataset.map(lambda a, b: a)
b = dataset.map(lambda a, b: b)
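The same map-based workaround can be generalized to any number of components. The following is only a sketch: `unzip` and `make_selector` are hypothetical names, not part of tf.data, and it assumes every element of the dataset is a flat tuple.

```python
import tensorflow as tf


def unzip(dataset):
    """Split a dataset of tuples into a tuple of datasets, one per component.

    Hypothetical helper (not a tf.data API); assumes flat tuple elements.
    """
    num_components = len(dataset.element_spec)

    def make_selector(index):
        # map() unpacks tuple elements into positional arguments,
        # so we can pick one component by position.
        def select(*components):
            return components[index]
        return select

    return tuple(dataset.map(make_selector(i)) for i in range(num_components))


zipped = tf.data.Dataset.zip((
    tf.data.Dataset.from_tensor_slices([1, 2, 3]),
    tf.data.Dataset.from_tensor_slices([4, 5, 6]),
))
first, second = unzip(zipped)
print(list(first.as_numpy_iterator()))
print(list(second.as_numpy_iterator()))
```

Each returned dataset is still lazy, so nothing is loaded into memory up front, but iterating over each one re-runs the upstream pipeline once per component.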

TensorFlow now provides get_single_element(), which can be used to unzip datasets (as asked in the question above).

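This avoids having to create and consume an iterator with iter(), or to make a separate pass over the dataset with .map() for each component, which could be costly for big datasets. Keep in mind that batching the whole dataset into one element loads all of it into memory at once.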

get_single_element() returns the only element of a dataset as a tensor (or a tuple or dict of tensors), and raises an error if the dataset does not contain exactly one element. To get all the members of the dataset this way, we first need to batch them into a single element.

This can be used to get the features as a single tensor, or the features and labels as a tuple or dictionary of tensors, depending on how the original dataset was created.

import tensorflow as tf

a = [ 1, 2, 3 ]
b = [ 4, 5, 6 ]
c = [ (7, 8), (9, 10), (11, 12) ]
d = [ 13, 14 ]
# Creating datasets from lists
ads = tf.data.Dataset.from_tensor_slices(a)
bds = tf.data.Dataset.from_tensor_slices(b)
cds = tf.data.Dataset.from_tensor_slices(c)
dds = tf.data.Dataset.from_tensor_slices(d)

list(tf.data.Dataset.zip((ads, bds)).as_numpy_iterator()) == [ (1, 4), (2, 5), (3, 6) ] # True
list(tf.data.Dataset.zip((bds, ads)).as_numpy_iterator()) == [ (4, 1), (5, 2), (6, 3) ] # True

# Let's zip and unzip ads and dds
x = tf.data.Dataset.zip((ads, dds))
# Batch the whole dataset into one element, then extract that element
xa, xd = x.batch(len(x)).get_single_element()
xa = xa.numpy().tolist()
xd = xd.numpy().tolist()
# zip stops at the shorter dataset, so ads was truncated to two elements and xa differs from a
print(xa, xd) # [1, 2] [13, 14]
d == xd # True
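For a dataset built from a dictionary, the same approach returns a dictionary of tensors. This is a minimal sketch: the keys "features" and "label" and the sample values are made up for illustration. It assumes a TensorFlow version where get_single_element() is available as a Dataset method (older releases expose it as tf.data.experimental.get_single_element) and a dataset whose length is known, so that len() works.

```python
import tensorflow as tf

# Made-up sample data for illustration
features = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
labels = [0, 1, 0]

ds = tf.data.Dataset.from_tensor_slices({"features": features, "label": labels})

# Batch everything into one element, then pull that element out as a dict of tensors
batch = ds.batch(len(ds)).get_single_element()

unzipped_features = batch["features"].numpy()
unzipped_labels = batch["label"].numpy()
print(unzipped_features.shape, unzipped_labels.shape)
```

Because the whole dataset ends up in a single batch, this is best suited to datasets that fit comfortably in memory.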