TensorFlow Dataset: shuffle then batch, or batch then shuffle?

I fully agree with @mrry, but there is one case where you might want to batch before shuffling. Suppose you're processing text data that will be fed into an RNN. Each sentence is treated as one sequence, and one batch contains multiple sequences. Since sentence length varies, the sentences in a batch must be padded to a uniform length. An efficient way to do this is to group sentences of similar length together through batching, and then shuffle the resulting batches. Otherwise, short sentences get padded out to the longest sentence in their batch, and we may end up with batches that consist mostly of <pad> tokens.
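To illustrate the idea, here is a minimal sketch in plain Python rather than tf.data, so it only shows the ordering of the steps (sort by length, batch and pad, then shuffle whole batches). The function name `length_bucketed_batches`, the helper `count_pad`, the `PAD` id of 0, and the toy token ids are all made up for this example.

```python
import random

PAD = 0  # hypothetical id standing in for the <pad> token


def length_bucketed_batches(sentences, batch_size, seed=None):
    """Batch sentences of similar length together, then shuffle the batches."""
    rng = random.Random(seed)
    # Sort by length so that each batch holds sentences of similar length.
    ordered = sorted(sentences, key=len)
    batches = []
    for start in range(0, len(ordered), batch_size):
        chunk = ordered[start:start + batch_size]
        width = max(len(s) for s in chunk)
        # Pad only up to the longest sentence in this particular batch.
        batches.append([s + [PAD] * (width - len(s)) for s in chunk])
    # Shuffle whole batches: their order is random, their contents stay grouped.
    rng.shuffle(batches)
    return batches


def count_pad(batches):
    """Count how many PAD tokens were needed across all batches."""
    return sum(row.count(PAD) for batch in batches for row in batch)


# Toy token-id sequences of varying length (no real token uses id 0).
sentences = [[1, 2], [3, 4, 5, 6, 7], [8], [9, 10, 11], [12, 13, 14, 15], [16, 17]]
batches = length_bucketed_batches(sentences, batch_size=2, seed=0)
for batch in batches:
    print(batch)
print("pad tokens used:", count_pad(batches))
```

On this toy input the grouped version needs 3 pad tokens, whereas batching the sentences in their original order would need 7. The trade-off is that the same sentences always share a batch; only the order of the batches changes.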


TL;DR: Yes, there is a difference. Almost always, you will want to call Dataset.shuffle() before Dataset.batch(). There is no shuffle_batch() method on the tf.data.Dataset class, and you must call the two methods separately to shuffle and batch a dataset.


The transformations of a tf.data.Dataset are applied in the same sequence that they are called. Dataset.batch() combines consecutive elements of its input into a single, batched element in the output. We can see the effect of the order of operations by considering the following two datasets:

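import tensorflow as tf

# To simplify the example code. TF 1.x only: in TF 2.x eager execution is
# already the default, so this call should be removed there.
tf.enable_eager_execution()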

# Batch before shuffle.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.batch(3)
dataset = dataset.shuffle(9)

for elem in dataset:
  print(elem)

# Prints (the order of the batches varies between runs):
# tf.Tensor([1 1 1], shape=(3,), dtype=int32)
# tf.Tensor([2 2 2], shape=(3,), dtype=int32)
# tf.Tensor([0 0 0], shape=(3,), dtype=int32)

# Shuffle before batch.
dataset = tf.data.Dataset.from_tensor_slices([0, 0, 0, 1, 1, 1, 2, 2, 2])
dataset = dataset.shuffle(9)
dataset = dataset.batch(3)

for elem in dataset:
  print(elem)

# Prints (the contents of the batches vary between runs):
# tf.Tensor([2 0 2], shape=(3,), dtype=int32)
# tf.Tensor([2 1 0], shape=(3,), dtype=int32)
# tf.Tensor([0 1 1], shape=(3,), dtype=int32)

In the first version (batch before shuffle), the elements of each batch are 3 consecutive elements from the input; whereas in the second version (shuffle before batch), they are randomly sampled from the input. Typically, when training by (some variant of) mini-batch stochastic gradient descent, the elements of each batch should be sampled as uniformly as possible from the total input. Otherwise, it is possible that the network will overfit to whatever structure was in the input data, and the resulting network will not achieve as high an accuracy.