What does `tf.strided_slice()` do?

What really helped me understand this function was realizing that it emulates the indexing behavior of numpy arrays.

If you're familiar with numpy arrays, you'll know that you can make slices via input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN]. It is basically a very succinct way of writing the for loops that pick out certain elements of the array.

(If you're familiar with python indexing, you know that you can grab an array slice via input[start:end:step]. Numpy arrays, which may be nested, make use of the above tuple of slice objects.)

Well, strided_slice just allows you to do this fancy indexing without the syntactic sugar. The numpy example from above just becomes

# input[start1:end1:step1, start2:end2:step2, ... startN:endN:stepN]
tf.strided_slice(input, [start1, start2, ..., startN],
    [end1, end2, ..., endN], [step1, step2, ..., stepN])
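
To make the correspondence concrete, here is a minimal sketch that uses numpy as a stand-in for TensorFlow. The helper name strided_slice_np is hypothetical (it is not a library function); it only builds the tuple of slice objects that the bracket syntax creates behind the scenes, taking its arguments in the same order as tf.strided_slice.

```python
import numpy as np

def strided_slice_np(x, begin, end, strides):
    # Hypothetical helper: one slice object per (begin, end, stride) triple,
    # which is what x[b1:e1:s1, b2:e2:s2, ...] builds implicitly.
    index = tuple(slice(b, e, s) for b, e, s in zip(begin, end, strides))
    return x[index]

x = np.arange(100).reshape(10, 10)
sugar = x[3:7:1, 5:10:2]
explicit = strided_slice_np(x, [3, 5], [7, 10], [1, 2])
print(np.array_equal(sugar, explicit))  # True
```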

The documentation is a bit confusing about this in the sense that:

a) end - begin is not strictly the shape of the return value:

The documentation claims otherwise, but this is only true if your strides are all ones. With other (positive) strides, each sliced dimension has ceil((end - begin) / stride) elements. Examples:

rank1 = tf.constant(list(range(10)))
# The below op is basically:
# rank1[1:10:2] => [1, 3, 5, 7, 9]
tf.strided_slice(rank1, [1], [10], [2])

# [10,10] grid of the numbers from 0 to 99
rank2 = tf.constant([[i+j*10 for i in range(10)] for j in range(10)])
# The below op is basically:
# rank2[3:7:1, 5:10:2] => numbers 30 - 69, ending in 5, 7, or 9
sliced = tf.strided_slice(rank2, [3, 5], [7, 10], [1, 2])
# The below op is basically:
# rank2[3:7:1] => numbers 30 - 69
sliced = tf.strided_slice(rank2, [3], [7], [1]) 
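
The shapes of the examples above can be checked with a little arithmetic. This is a rough sketch, not TensorFlow's own shape logic: the helper sliced_shape is a made-up name, and it assumes positive strides and 0 <= begin <= end in every sliced dimension.

```python
import math

def sliced_shape(begin, end, strides):
    # Hypothetical helper; assumes positive strides and 0 <= begin <= end.
    return [math.ceil((e - b) / s) for b, e, s in zip(begin, end, strides)]

print(sliced_shape([1], [10], [2]))           # [5], while end - begin is [9]
print(sliced_shape([3, 5], [7, 10], [1, 2]))  # [4, 3], while end - begin is [4, 5]
print(sliced_shape([3, 5], [7, 10], [1, 1]))  # [4, 5], matches end - begin
```

Only the last call, where every stride is 1, agrees with end - begin.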

b) it states that "begin, end, and strides will be all length n, where n is in general not the same dimensionality as input"

It sounds like dimensionality means rank here - but input does have to be a tensor of at least rank-n; it can't be lower (see rank-2 example above).
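
The rank constraint is easy to see with numpy, used here as a stand-in on the assumption that tf.strided_slice follows the same rule: fewer slices than dimensions leaves the trailing dimensions whole, while more slices than dimensions is an error.

```python
import numpy as np

rank2 = np.arange(100).reshape(10, 10)

# One slice for a rank-2 array: only the first dimension is restricted.
rows = rank2[3:7:1]
print(rows.shape)  # (4, 10)

# Three slices for a rank-2 array: numpy refuses.
try:
    rank2[3:7:1, 0:10:1, 0:1:1]
    too_many_ok = True
except IndexError:
    too_many_ok = False
print(too_many_ok)  # False
```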

N.B. I've said nothing/not really explored the masking feature, but that seems beyond the scope of the question.


I experimented a bit with this method and gained some insights that might be of use. Let's say we have a tensor.

a = np.array([[[1, 1.2, 1.3], [2, 2.2, 2.3], [7, 7.2, 7.3]],
              [[3, 3.2, 3.3], [4, 4.2, 4.3], [8, 8.2, 8.3]],
              [[5, 5.2, 5.3], [6, 6.2, 6.3], [9, 9.2, 9.3]]]) 
# a.shape = (3, 3, 3)

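strided_slice() takes four required arguments: input_, begin, end, and strides. Here we pass our a as the input_ argument. Unlike tf.slice(), which takes a zero-based begin and a size (a shape), strided_slice() takes both begin and end as zero-based indices, as the docs describe. The end index is exclusive, so passing the size of a dimension as end covers that whole dimension, which can make end look shape-based.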

The functionality of the method is quite simple:
It works like iterating over a loop, where begin is the location of the element in the tensor at which the loop starts and end is where it stops.

tf.strided_slice(a, [0, 0, 0], [3, 3, 3], [1, 1, 1])

# output =  the tensor itself

tf.strided_slice(a, [0, 0, 0], [3, 3, 3], [2, 2, 2])

# output = [[[ 1.   1.3]
#            [ 7.   7.3]]
#           [[ 5.   5.3]
#            [ 9.   9.3]]]

The strides are the step sizes of that loop. Here [2,2,2] makes the method pick the values at (0,0,0), (0,0,2), (0,2,0), (0,2,2), (2,0,0), (2,0,2), (2,2,0), (2,2,2) in the tensor a.
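
The list of visited positions can be generated rather than written out by hand. A small sketch in plain Python, assuming positive strides: each dimension contributes a range(begin, end, stride), and itertools.product walks them with the last dimension changing fastest.

```python
from itertools import product

begin, end, strides = [0, 0, 0], [3, 3, 3], [2, 2, 2]
visited = list(product(*(range(b, e, s) for b, e, s in zip(begin, end, strides))))
print(visited)
# [(0, 0, 0), (0, 0, 2), (0, 2, 0), (0, 2, 2),
#  (2, 0, 0), (2, 0, 2), (2, 2, 0), (2, 2, 2)]
```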

tf.strided_slice(a, [1, 1, 0], [2, -1, 3], [1, 1, 1])

will produce the same output as tf.strided_slice(a, [1, 1, 0], [2, 2, 3], [1, 1, 1]), because the tensor a has shape (3, 3, 3) and a negative end index counts from the end of its dimension, so -1 means 3 - 1 = 2.
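
The effect of the negative end index can be checked on the tensor a from above with plain numpy slicing, which is assumed here to treat negative indices the same way strided_slice does.

```python
import numpy as np

a = np.array([[[1, 1.2, 1.3], [2, 2.2, 2.3], [7, 7.2, 7.3]],
              [[3, 3.2, 3.3], [4, 4.2, 4.3], [8, 8.2, 8.3]],
              [[5, 5.2, 5.3], [6, 6.2, 6.3], [9, 9.2, 9.3]]])

with_negative = a[1:2, 1:-1, 0:3]  # end of -1 in a dimension of size 3
with_positive = a[1:2, 1:2, 0:3]   # the same end written as 2
print(with_negative.tolist())      # [[[4.0, 4.2, 4.3]]]
print(np.array_equal(with_negative, with_positive))  # True
```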


The mistake in your argument is that you add the strides list to the begin list element by element, all at once. That would make the function a lot less useful. Instead, the function increments the begin index one dimension at a time, starting from the last dimension.

Let's solve the first example part by part. begin = [1, 0, 0] and end = [2, 1, 3]. Also, all the strides are 1. Work your way backwards, from the last dimension.

Start with element [1,0,0]. Now increase only the last dimension by its stride, giving you [1,0,1]. Keep doing this until you reach the limit: [1,0,2] comes next, and then [1,0,3], which equals end[2], so the inner loop ends there and [1,0,3] itself is not included. In the next iteration, increment the second-to-last dimension and reset the last dimension, giving [1,1,0]. Here the second-to-last dimension is equal to end[1], so move to the first dimension (third to last) and reset the rest, giving you [2,0,0]. Again you are at the first dimension's limit, so quit the loop.
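
The walk described above can also be written as an explicit "odometer" loop. This is only a sketch of the description, not TensorFlow's actual implementation; the function name odometer_indices is made up, and it assumes positive strides and non-negative begin and end values.

```python
def odometer_indices(begin, end, strides):
    """Yield the visited index tuples, last dimension changing fastest."""
    # Nothing to visit if any dimension is empty.
    if any(b >= e for b, e in zip(begin, end)):
        return
    active = list(begin)
    while True:
        yield tuple(active)
        dim = len(active) - 1
        while dim >= 0:
            active[dim] += strides[dim]
            if active[dim] < end[dim]:
                break                    # still inside this dimension
            active[dim] = begin[dim]     # hit the limit: reset and carry
            dim -= 1
        if dim < 0:
            return                       # the first dimension hit its limit

visited = list(odometer_indices([1, 0, 0], [2, 1, 3], [1, 1, 1]))
print(visited)  # [(1, 0, 0), (1, 0, 1), (1, 0, 2)]
```

The carry step is the part that differs from adding strides to begin element by element: only one dimension moves at a time, and everything to its right is reset.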

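The following code is a recursive implementation of what I described above. It assumes positive strides and non-negative indices:

import numpy as np

def strided_slice_loop(old_matrix, begin, end, stride):
    old_matrix = np.asarray(old_matrix)
    # Steps taken per sliced dimension; unsliced trailing dimensions are kept whole
    shape = [len(range(b, e, s)) for b, e, s in zip(begin, end, stride)]
    shape += list(old_matrix.shape[len(begin):])
    new_matrix = np.empty(shape, dtype=old_matrix.dtype)

    def iterate(active, dim):
        if dim == len(begin):
            # Every sliced dimension is fixed, so copy this element.
            # The target index counts steps from `begin`, hence the division.
            target = tuple((a - b) // s for a, b, s in zip(active, begin, stride))
            new_matrix[target] = old_matrix[tuple(active)]
        else:
            for i in range(begin[dim], end[dim], stride[dim]):
                new_active = list(active)  # copy so sibling calls don't interfere
                new_active[dim] = i
                iterate(new_active, dim + 1)

    iterate(list(begin), 0)
    return new_matrix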