How to select rows from a 3-D Tensor in TensorFlow?

This is possible in TensorFlow, but slightly inconvenient, because at the time of writing tf.gather() only works with one-dimensional indices and only selects slices from the 0th dimension of a tensor. However, you can still solve the problem efficiently by transforming the arguments so that they can be passed to tf.gather():

import tensorflow as tf

# Illustrative values; substitute your own [2 x 4 x 4] tensor.
logits = tf.reshape(tf.range(32, dtype=tf.float32), [2, 4, 4])
indices = tf.constant([[0, 1], [1, 3]])

# Use tf.shape() to make this work with dynamic shapes.
batch_size = tf.shape(logits)[0]
rows_per_batch = tf.shape(logits)[1]
indices_per_batch = tf.shape(indices)[1]

# Offset to add to each row in indices. We use `tf.expand_dims()` to make
# this broadcast appropriately.
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1)

# Convert indices and logits into appropriate form for `tf.gather()`.
flattened_indices = tf.reshape(indices + offset, [-1])
flattened_logits = tf.reshape(logits, tf.concat([[-1], tf.shape(logits)[2:]], axis=0))

selected_rows = tf.gather(flattened_logits, flattened_indices)

# Since TF 1.0, tf.concat(values, axis) and tf.stack() replace the older
# tf.concat(axis, values) and tf.pack().
result = tf.reshape(selected_rows,
                    tf.concat([tf.stack([batch_size, indices_per_batch]),
                               tf.shape(logits)[2:]], axis=0))

Note that, since this uses tf.reshape() and not tf.transpose(), it doesn't need to modify the (potentially large) data in the logits tensor, so it should be fairly efficient.
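The same flatten-offset-gather arithmetic can be checked outside TensorFlow. Below is a minimal sketch that assumes NumPy as a stand-in: `np.reshape` plays the role of tf.reshape() and `np.take(..., axis=0)` the role of tf.gather(). The function name `select_rows_flat` is hypothetical. The sketch also uses `np.shares_memory` to show that, in NumPy at least, reshaping a contiguous array does not copy its data.

```python
import numpy as np


def select_rows_flat(logits, indices):
    """Return logits[b, indices[b, j], :] for every batch b and position j.

    Hypothetical NumPy stand-in for the tf.reshape/tf.gather approach:
    flatten the first two dimensions, shift each batch's row indices by
    b * rows_per_batch, gather along axis 0, then reshape back.
    """
    batch_size, rows_per_batch = logits.shape[:2]
    indices_per_batch = indices.shape[1]

    # Column vector of per-batch offsets; broadcasts against `indices`.
    offset = (np.arange(batch_size) * rows_per_batch)[:, np.newaxis]

    flattened_indices = (indices + offset).reshape(-1)
    flattened_logits = logits.reshape((-1,) + logits.shape[2:])

    # np.take(..., axis=0) does what tf.gather() does in the answer.
    selected_rows = np.take(flattened_logits, flattened_indices, axis=0)
    return selected_rows.reshape((batch_size, indices_per_batch) + logits.shape[2:])


logits = np.arange(32, dtype=np.float32).reshape(2, 4, 4)
indices = np.array([[0, 1], [1, 3]])
result = select_rows_flat(logits, indices)
print(result.shape)

# Reshaping a contiguous array returns a view, so no data is copied.
print(np.shares_memory(logits.reshape(-1, 4), logits))
```

With indices [[0, 1], [1, 3]] the offsets are [[0], [4]], so the flattened indices are [0, 1, 5, 7]: rows 5 and 7 of the flattened [8 x 4] array are logits[1, 1] and logits[1, 3].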


mrry's answer is great, but I think the problem can be solved in far fewer lines of code with tf.gather_nd (this function was probably not yet available when mrry's answer was written):

import tensorflow as tf

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
                      [11.0, 10.0, 10.0, 30.0],
                      [12.0, 10.0, 10.0, 20.0],
                      [13.0, 10.0, 10.0, 20.0]],
                     [[14.0, 11.0, 21.0, 31.0],
                      [15.0, 11.0, 11.0, 21.0],
                      [16.0, 11.0, 11.0, 21.0],
                      [17.0, 11.0, 11.0, 21.0]]])

indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]])

result = tf.gather_nd(logits, indices)
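# TF 2.x executes eagerly, so the result can be printed directly.
# In TF 1.x, evaluate it in a session instead:
#     with tf.Session() as sess:
#         print(sess.run(result))
print(result.numpy())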

This will print

[[[ 10.  10.  20.  20.]
  [ 11.  10.  10.  30.]]

 [[ 15.  11.  11.  21.]
  [ 17.  11.  11.  21.]]]
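tf.gather_nd needs an explicit batch coordinate for every selected row, which is why `indices` here has shape [2 x 2 x 2] rather than the [2 x 2] used in the first answer. The sketch below, again assuming NumPy as a stand-in for TensorFlow, shows how the compact row indices can be expanded into that form and how the lookup then behaves. Both `to_nd_indices` and `gather_nd_np` are hypothetical helpers, not TensorFlow APIs; in TensorFlow itself the same expansion could be built from ops such as tf.range, tf.tile and tf.stack.

```python
import numpy as np


def to_nd_indices(row_indices):
    """Hypothetical helper: pair each row index with its batch index.

    [B, K] row indices -> [B, K, 2] (batch, row) coordinates.
    """
    batch_size, k = row_indices.shape
    batch_ids = np.repeat(np.arange(batch_size)[:, np.newaxis], k, axis=1)
    return np.stack([batch_ids, row_indices], axis=-1)


def gather_nd_np(params, nd_indices):
    """Hypothetical NumPy emulation of tf.gather_nd for this example.

    The last axis of `nd_indices` holds coordinates into the leading
    dimensions of `params`; the remaining dimensions are kept whole.
    """
    coords = tuple(np.moveaxis(nd_indices, -1, 0))
    return params[coords]


logits = np.array([[[10.0, 10.0, 20.0, 20.0],
                    [11.0, 10.0, 10.0, 30.0],
                    [12.0, 10.0, 10.0, 20.0],
                    [13.0, 10.0, 10.0, 20.0]],
                   [[14.0, 11.0, 21.0, 31.0],
                    [15.0, 11.0, 11.0, 21.0],
                    [16.0, 11.0, 11.0, 21.0],
                    [17.0, 11.0, 11.0, 21.0]]], dtype=np.float32)

nd_indices = to_nd_indices(np.array([[0, 1], [1, 3]]))
print(nd_indices.tolist())
out = gather_nd_np(logits, nd_indices)
print(out)
```

The expanded indices are [[[0, 0], [0, 1]], [[1, 1], [1, 3]]], the same array passed to tf.gather_nd above, and the selected rows match the printed result.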

tf.gather_nd should be available as of v0.10. See this GitHub issue for more discussion.
