In TensorFlow, what does the 'axis' argument of the function 'tf.one_hot' do?

Here's an example:

x = tf.constant([0, 1, 2])

... is the input tensor and N=4 (each index is transformed into a 4-dimensional vector).

axis=-1

Computing one_hot_1 = tf.one_hot(x, 4).eval() yields a (3, 4) tensor:

[[ 1.  0.  0.  0.]
 [ 0.  1.  0.  0.]
 [ 0.  0.  1.  0.]]

... where the last dimension is one-hot encoded (clearly visible). This corresponds to the default axis=-1, i.e. the last one.

axis=0

Now, computing one_hot_2 = tf.one_hot(x, 4, axis=0).eval() yields a (4, 3) tensor, which is not immediately recognizable as one-hot encoded:

[[ 1.  0.  0.]
 [ 0.  1.  0.]
 [ 0.  0.  1.]
 [ 0.  0.  0.]]

This is because the one-hot encoding is done along axis 0, so one has to transpose the matrix to see the previous encoding. The situation becomes more complicated when the input is higher dimensional, but the idea is the same: the only difference is where the extra dimension used for the one-hot encoding is placed.
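
To make the higher-dimensional case concrete, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution (so no .eval() is needed), and the 2-D input values are made up for illustration. It only inspects where the new dimension of size 4 ends up for each choice of axis.

```python
import tensorflow as tf  # assumption: TensorFlow 2.x, eager execution

# Illustrative 2-D batch of class indices, shape (2, 3).
x = tf.constant([[0, 1, 2],
                 [3, 0, 1]])
depth = 4

last = tf.one_hot(x, depth)             # default axis=-1, shape (2, 3, 4)
first = tf.one_hot(x, depth, axis=0)    # new dimension in front, shape (4, 2, 3)
middle = tf.one_hot(x, depth, axis=1)   # new dimension in the middle, shape (2, 4, 3)

for name, t in [("axis=-1", last), ("axis=0", first), ("axis=1", middle)]:
    print(name, t.shape.as_list())

# Moving the depth dimension back to the end recovers the default encoding.
same = tf.reduce_all(tf.equal(tf.transpose(first, perm=[1, 2, 0]), last))
print(bool(same))
```

The values are identical in all three results; only the position of the one-hot dimension differs.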


For me, axis translates to "where do you add the additional numbers to increase the dimension". At least this is how I interpret it, and it serves as a mnemonic.

For instance, say you have [1,2,3,0,2,1], which has shape (6,), meaning it is a one-dimensional array, and that there are 4 classes (0 to 3). one_hot replaces every element of your original array with a vector of zeros that has a 1 at the position given by that element. For this to happen, the result must have one more dimension than the original array, and axis tells the function where to add it. The new dimension indexes the classes, while the existing dimension still identifies the examples.


axis=1

You add a second dimension and the first dimension is kept. This results in a (6,4) array. So for the resulting array, you use the first dimension (0) to know which example you are looking at and the second dimension (1, the new one) to know whether a given class is active. newArr[0][1]=1 means example 0, class 1, which in this case means example 0 is of class 1.

   0   1   2   3  <- class
[[ 0.  1.  0.  0.]   <- example 0
 [ 0.  0.  1.  0.]   <- example 1
 [ 0.  0.  0.  1.]   <- example 2
 [ 1.  0.  0.  0.]   <- example 3
 [ 0.  0.  1.  0.]   <- example 4
 [ 0.  1.  0.  0.]]  <- example 5

axis=0

You add a first dimension and the existing dimension is shifted. This results in a (4,6) array. So for the resulting array, you use the first dimension (0, the new one) to know whether a given class is active and the second dimension (1) to know which example you are looking at. newArr[0][1]=0 means class 0, example 1, which in this case means example 1 is not of class 0.

   0   1   2   3   4   5  <- example
[[ 0.  0.  0.  1.  0.  0.]   <- class 0
 [ 1.  0.  0.  0.  0.  1.]   <- class 1
 [ 0.  1.  0.  0.  1.  0.]   <- class 2
 [ 0.  0.  1.  0.  0.  0.]]  <- class 3
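
The mnemonic can also be written out in a few lines of NumPy. This is only a sketch of the idea, not how TensorFlow implements tf.one_hot: the helper name one_hot_sketch is made up, and it assumes every index lies in the range 0 to depth-1.

```python
import numpy as np

def one_hot_sketch(indices, depth, axis=-1):
    """Hypothetical helper illustrating the axis mnemonic (not TensorFlow code)."""
    indices = np.asarray(indices)
    # Row i of the identity matrix is the one-hot vector for class i,
    # so indexing appends the class dimension at the end.
    encoded = np.eye(depth, dtype=np.float32)[indices]
    # axis says where that new class dimension should be placed.
    return np.moveaxis(encoded, -1, axis)

labels = [1, 2, 3, 0, 2, 1]
by_example = one_hot_sketch(labels, 4, axis=1)  # shape (6, 4): rows are examples
by_class = one_hot_sketch(labels, 4, axis=0)    # shape (4, 6): rows are classes

print(by_example.shape, by_class.shape)
print(by_example[0][1], by_class[0][1])
```

With a one-dimensional input, the two results are simply transposes of each other, which matches the two matrices shown above.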