Extract target from TensorFlow PrefetchDataset

You can use map to select either the input or the label from every (input, label) pair and collect the results in a list:

import tensorflow as tf
import numpy as np

inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

ds = tf.data.Dataset.from_tensor_slices((inputs, targets))

X_train = list(map(lambda x: x[0], ds))
y_train = list(map(lambda x: x[1], ds))
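
If you would rather end up with a single NumPy array than a list of tensors, something like the following might also work. It is only a sketch: the batch size of 32 and the prefetch(1) call are assumptions added here so that ds is a batched PrefetchDataset, and it relies on the dataset not being shuffled.

```python
import numpy as np
import tensorflow as tf

inputs = np.random.rand(100, 99)
targets = np.random.rand(100)

# Assumed setup: batch and prefetch so that ds is a PrefetchDataset.
ds = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(32).prefetch(1)

# Keep only the label from each (input, label) pair.
labels_ds = ds.map(lambda x, y: y)

# Each element is now one batch of labels; join them into one array.
y_all = np.concatenate(list(labels_ds.as_numpy_iterator()), axis=0)
```

Because no shuffling is involved, y_all should come back in the same order as targets.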

If you want to retain the batches, or extract all the labels as a single tensor, you could use the following function:


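def get_labels_from_tfdataset(tfdataset, batched=False):
    """Return the labels of a batched dataset of (input, label) pairs."""
    # Iterate over the dataset and keep the second element of every pair.
    labels = list(map(lambda x: x[1], tfdataset))

    if not batched:
        # Join the per-batch label tensors into a single tensor. This assumes
        # a batched dataset: tf.concat cannot join scalar labels (use tf.stack).
        return tf.concat(labels, axis=0)

    # Otherwise return a list with one label tensor per batch.
    return labels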
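
As a rough usage sketch, here is how the function might be called on a batched dataset. The random data and the batch size of 32 are assumptions for illustration, and the function is repeated so the snippet runs on its own.

```python
import numpy as np
import tensorflow as tf

def get_labels_from_tfdataset(tfdataset, batched=False):
    # Same function as above, repeated so this snippet is self-contained.
    labels = list(map(lambda x: x[1], tfdataset))
    if not batched:
        return tf.concat(labels, axis=0)
    return labels

# Assumed example data: 100 samples split into batches of 32.
inputs = np.random.rand(100, 99)
targets = np.random.rand(100)
batched_ds = tf.data.Dataset.from_tensor_slices((inputs, targets)).batch(32)

per_batch = get_labels_from_tfdataset(batched_ds, batched=True)
all_labels = get_labels_from_tfdataset(batched_ds)

print(len(per_batch))    # number of batches
print(all_labels.shape)  # shape of the concatenated label tensor
```

With batched=True you get one tensor per batch (the last one smaller if the data does not divide evenly). With the default you get all the labels in one tensor.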

You can convert it to a list with list(ds) and then recompile it as a normal Dataset with tf.data.Dataset.from_tensor_slices(list(ds)). From there your nightmare begins again but at least it's a nightmare that other people have had before.

Note that for more complex datasets (e.g. nested dictionaries) you will need more preprocessing after calling list(ds), but this should work for the example you asked about.

This is far from a satisfying answer but unfortunately the class is entirely undocumented and none of the standard Dataset tricks work.


You can build a list by looping over your PrefetchDataset, which is train_dataset in my example:

train_data = [(example.numpy(), label.numpy()) for example, label in train_dataset]

You can then reach each example and label separately by index:

train_data[0][0]
train_data[0][1]

You can also convert them into a DataFrame with two columns using pandas:

import pandas as pd
pd.DataFrame(train_data, columns=['example', 'label'])

Then, if you want to convert your list back into a tf.data.Dataset, you can use from_generator (call prefetch on the result if you need a PrefetchDataset again):

dataset = tf.data.Dataset.from_generator(
    lambda: train_data,
    (tf.string, tf.int32))  # replace these with the dtypes of your own data

And you can check that it worked with:

list(dataset.as_numpy_iterator())
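
On newer TensorFlow versions (2.4 or later, as far as I know) you could describe the elements with output_signature instead of a bare tuple of dtypes. The sketch below is only an illustration: the two-element train_data is a made-up stand-in for your own list, and the scalar shapes are an assumption about what your examples look like.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for train_data: (text, label) pairs as NumPy values.
train_data = [(b"good movie", np.int32(1)), (b"bad movie", np.int32(0))]

dataset = tf.data.Dataset.from_generator(
    lambda: train_data,
    output_signature=(
        tf.TensorSpec(shape=(), dtype=tf.string),  # assumed scalar string example
        tf.TensorSpec(shape=(), dtype=tf.int32),   # assumed scalar integer label
    ),
).prefetch(tf.data.AUTOTUNE)  # prefetch wraps the result in a PrefetchDataset

# Round-trip check: pull everything back out as NumPy values.
restored = list(dataset.as_numpy_iterator())
```

Spelling out the shapes lets TensorFlow check each element against the declared spec, which tends to surface mismatches earlier than dtypes alone.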