numpy random shuffle by row independently

import numpy as np
np.random.seed(2018)

def scramble(a, axis=-1):
    """
    Return an array with the values of `a` shuffled along the given axis.
    The same random permutation is applied to every slice along that axis.
    """
    b = a.swapaxes(axis, -1)
    n = a.shape[axis]
    idx = np.random.choice(n, n, replace=False)
    b = b[..., idx]
    return b.swapaxes(axis, -1)

a = np.arange(4*9).reshape(4, 9)
# array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8],
#        [ 9, 10, 11, 12, 13, 14, 15, 16, 17],
#        [18, 19, 20, 21, 22, 23, 24, 25, 26],
#        [27, 28, 29, 30, 31, 32, 33, 34, 35]])

print(scramble(a, axis=1))

yields

[[ 3  8  7  0  4  5  1  2  6]
 [12 17 16  9 13 14 10 11 15]
 [21 26 25 18 22 23 19 20 24]
 [30 35 34 27 31 32 28 29 33]]

while scrambling along the 0-axis:

print(scramble(a, axis=0))

yields

[[18 19 20 21 22 23 24 25 26]
 [ 0  1  2  3  4  5  6  7  8]
 [27 28 29 30 31 32 33 34 35]
 [ 9 10 11 12 13 14 15 16 17]]

This works by first swapping the target axis with the last axis:

b = a.swapaxes(axis, -1)

This is a common trick for standardizing code that operates on a single axis: it reduces the general case to the special case of working on the last axis. In NumPy 1.10 and later, swapaxes returns a view of the array, so no data is copied and the call is very fast.
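The view behaviour can be checked directly. The snippet below is a small sketch, assuming NumPy 1.10 or later; `np.shares_memory` is used only to confirm that the swapped array and the original use the same buffer.

```python
import numpy as np

a = np.arange(4 * 9).reshape(4, 9)
b = a.swapaxes(0, -1)

# The swapped array shares its buffer with the original, so no data was copied.
shares = np.shares_memory(a, b)
print(b.shape, shares)

# Writing through the view changes the original array as well.
b[0, 0] = -1
print(a[0, 0])
```

Because `b` is a view, it is only safe to modify in place if changing `a` is also intended. The `scramble` function avoids this, since the fancy indexing step produces a new array.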

Now we can generate a new index order for the last axis:

n = a.shape[axis]
idx = np.random.choice(n, n, replace=False)
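Drawing n items out of n without replacement is the same thing as drawing a random permutation of 0..n-1. As a sketch of an alternative spelling, assuming a NumPy version that provides `np.random.default_rng` (1.17 or later), the index order could also be produced as follows. The seed is arbitrary, and this generator uses a different random stream from `np.random.seed`, so it will not reproduce the outputs shown above.

```python
import numpy as np

n = 9
rng = np.random.default_rng(2018)  # seed chosen only for reproducibility
idx = rng.permutation(n)

# idx contains every index 0..n-1 exactly once, in random order.
print(idx)
```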

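Now we can reorder b along the last axis. Note that the same idx is applied to every slice, so every row of b receives the same permutation. The indexing also returns a new array rather than a view: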

b = b[..., idx]

and then reverse the swapaxes to return an a-shaped result:

return b.swapaxes(axis, -1)
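Since `scramble` draws a single `idx`, every slice ends up in the same order, as the outputs above show. If each slice should get its own permutation, one possible variant is sketched below. It is not part of the original answer: the name `scramble_independent` and its `rng` parameter are hypothetical, and it assumes NumPy 1.17 or later for `np.random.default_rng` (`np.take_along_axis` needs 1.15 or later).

```python
import numpy as np

def scramble_independent(a, axis=-1, rng=None):
    """Shuffle each 1-D slice of `a` along `axis` with its own permutation."""
    rng = np.random.default_rng() if rng is None else rng
    # Sorting random keys along `axis` yields a separate random permutation
    # of indices for every slice.
    keys = rng.random(a.shape)
    idx = keys.argsort(axis=axis)
    return np.take_along_axis(a, idx, axis=axis)

a = np.arange(4 * 9).reshape(4, 9)
out = scramble_independent(a, axis=1, rng=np.random.default_rng(2018))
print(out)
```

Each row of `out` holds the same values as the corresponding row of `a`, but the column order generally differs from row to row. On NumPy 1.20 and later, `Generator.permuted(a, axis=1)` should give the same kind of result without the explicit argsort.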
