TensorFlow image operations for batches

One possibility is to use the recently added tf.map_fn() to apply the single-image operator to each element of the batch.

import tensorflow as tf
result = tf.map_fn(tf.image.random_flip_left_right, images)  # images: [batch, height, width, channels]

This computes the same thing as the Python loop keveman suggests, but instead of adding a separate op for every image to the graph it uses TensorFlow's support for loops (tf.map_fn is built on tf.while_loop), which can be more efficient for larger batch sizes.
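A minimal sketch of the same idea, assuming TF 2.x with eager execution (newer than this answer) and a made-up batch of 8 tiny images. Because the random op runs once per element, each image gets its own coin flip, so every output image should be either unchanged or mirrored along its width axis; the check at the end relies only on that.

```python
import tensorflow as tf

# Hypothetical stand-in batch: 8 single-channel 4x5 images, shape [batch, height, width, channels].
images = tf.reshape(tf.range(8 * 4 * 5, dtype=tf.float32), [8, 4, 5, 1])

# tf.map_fn slices `images` along axis 0, calls the single-image op on each
# [height, width, channels] element and stacks the results back into a batch.
result = tf.map_fn(tf.image.random_flip_left_right, images)

# Every image mirrored along its width axis (axis 2 of the 4-D batch).
mirrored = tf.reverse(images, axis=[2])

# Per image: was it left as-is, or flipped?
is_same = tf.reduce_all(tf.equal(result, images), axis=[1, 2, 3])
is_mirrored = tf.reduce_all(tf.equal(result, mirrored), axis=[1, 2, 3])
print(is_mirrored.numpy())  # which images were flipped; varies from run to run
```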


You can call the image operation in a loop and stack the results back into one tensor. For example:

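import tensorflow as tf

transformed_images = []
for i in range(batch_size):
  # Apply the single-image op to the i-th [height, width, channels] slice of the batch
  transformed_images.append(tf.image.random_flip_left_right(images[i, :, :, :]))
result = tf.stack(transformed_images)  # back to [batch_size, height, width, channels]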
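A self-contained variant of this loop, assuming TF 2.x eager execution and a made-up 3-image batch. It uses tf.unstack instead of indexing; tf.unstack infers the number of pieces from the static batch dimension, which reflects the main constraint of the Python-loop approach: the batch size has to be known when the ops are created. The deterministic tf.image.flip_left_right stands in for the random variant so the result can be checked.

```python
import tensorflow as tf

# Hypothetical stand-in batch: 3 single-channel 4x5 images, shape [batch, height, width, channels].
images = tf.reshape(tf.range(3 * 4 * 5, dtype=tf.float32), [3, 4, 5, 1])

# tf.unstack returns a Python list of 3-D [height, width, channels] tensors;
# it infers the list length from the static batch dimension.
transformed_images = [tf.image.flip_left_right(img) for img in tf.unstack(images)]

# tf.stack joins the per-image results back into a 4-D batch.
result = tf.stack(transformed_images)
print(result.shape)  # (3, 4, 5, 1)
```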

TL;DR: you can create a queue, define the reading and processing for a single element of the queue, and then make a batch, all with TF methods.

I'm not sure exactly how it works, but if you use queues, read the images with TensorFlow ops and create batches from them, you can write the processing as if for a single image and still work with whole batches.

I haven't tested it on large datasets yet and don't know how well it performs (speed, memory consumption and so on). Maybe for now it's better to create the batch yourself.

I have seen this in the cifar10 example. You can see it here: https://github.com/tensorflow/tensorflow/tree/r0.10/tensorflow/models/image/cifar10

  1. First, they create a queue with tf.train.string_input_producer (https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10_input.py#L222). You can use a different type of queue; for example, I tried tf.train.slice_input_producer for multiple images. You can read about it in the question "Tensorflow read images with labels".
  2. Then they define all the needed operations as for a single image. If they only need reading, it is just reading; if they want processing, they crop the image and apply other distortions. Reading is implemented in read_cifar10, and processing in distorted_inputs (https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10_input.py#L138).
  3. They pass the results of step 2 to tf.train.batch or tf.train.shuffle_batch, depending on the parameters, and return the batches from the inputs() and distorted_inputs() functions.
  4. They then use the batches like ordinary tensors, images, labels = cifar10.distorted_inputs(), and continue with the rest of the training code (https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/models/image/cifar10/cifar10_train.py#L66).
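The queue-based tf.train input functions used in the r0.10 cifar10 example are deprecated in TF 2.x and only available under tf.compat.v1 in graph mode. As a swapped-in technique (not what the cifar10 example itself does), the same pattern of writing the processing for one element and batching afterwards can be sketched with tf.data, assuming TF 2.x. The data, the sizes and the process_one function are hypothetical; the crop and flip only loosely mirror the kind of processing described in step 2.

```python
import tensorflow as tf

# Hypothetical stand-in data: 10 RGB images of 32x32 pixels and integer labels.
images = tf.random.uniform([10, 32, 32, 3])
labels = tf.range(10)

def process_one(image, label):
    # Hypothetical per-element processing, written for a single [height, width, channels] image.
    image = tf.image.random_crop(image, size=[24, 24, 3])
    image = tf.image.random_flip_left_right(image)
    return image, label

dataset = (
    tf.data.Dataset.from_tensor_slices((images, labels))  # one element at a time (cf. step 1)
    .shuffle(buffer_size=10)                              # shuffling, as shuffle_batch does
    .map(process_one)                                     # single-image processing (cf. step 2)
    .batch(4)                                             # group into batches (cf. step 3)
)

for image_batch, label_batch in dataset:  # consume batches (cf. step 4)
    print(image_batch.shape, label_batch.shape)
# prints (4, 24, 24, 3) (4,) twice, then (2, 24, 24, 3) (2,)
```

The last batch is smaller because 10 is not divisible by 4; passing drop_remainder=True to batch() would drop it instead.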

Tags:

Tensorflow