More idiomatic way to display images in a grid with numpy

import numpy as np
import matplotlib.pyplot as plt

def gallery(array, ncols=3):
    """Tile a stack of equally sized images into one grid image.

    Parameters
    ----------
    array : numpy.ndarray
        Stack of images with shape ``(nindex, height, width, intensity)``.
        A stack of single-channel images ``(nindex, height, width)`` is also
        accepted; any axes after ``width`` are carried through unchanged.
    ncols : int, optional
        Number of images per grid row. Must be positive and evenly divide
        ``nindex``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(height*nrows, width*ncols, ...)`` with
        ``nrows = nindex // ncols``; image ``k`` occupies grid cell
        ``(k // ncols, k % ncols)``.

    Raises
    ------
    ValueError
        If ``array`` has fewer than 3 dimensions, ``ncols`` is not positive,
        or ``nindex`` is not a multiple of ``ncols``.
    """
    if array.ndim < 3:
        raise ValueError(
            "expected an image stack with at least 3 dimensions, "
            "got shape {}".format(array.shape))
    if ncols < 1:
        raise ValueError("ncols must be a positive integer, got {!r}".format(ncols))
    nindex, height, width = array.shape[:3]
    channels = array.shape[3:]  # () for grayscale, (intensity,) for colour
    nrows, remainder = divmod(nindex, ncols)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if remainder:
        raise ValueError(
            "{} images cannot fill a grid with {} columns".format(nindex, ncols))
    # Split axis 0 into (nrows, ncols), then swap so that nrows sits next to
    # height and ncols next to width:
    #   (nrows, ncols, height, width, ...) -> (nrows, height, ncols, width, ...)
    # Merging those adjacent pairs with a final reshape yields the grid.
    result = (array.reshape(nrows, ncols, height, width, *channels)
              .swapaxes(1, 2)
              .reshape(height * nrows, width * ncols, *channels))
    return result

def make_array(path='face.png', count=12):
    """Build a demo image stack for ``gallery``.

    Loads the image at ``path``, converts it to RGB and stacks ``count``
    copies of it along a new leading axis.

    Parameters
    ----------
    path : str or path-like, optional
        Image file to load. Defaults to ``'face.png'``.
    count : int, optional
        Number of copies to stack. Defaults to 12.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(count, height, width, 3)``.
    """
    from PIL import Image
    # The context manager closes the underlying file handle deterministically;
    # convert() returns an independent, fully loaded image.
    with Image.open(path) as img:
        rgb = np.asarray(img.convert('RGB'))
    return np.array([rgb] * count)

# Demo: tile the 12 stacked copies of face.png into a grid with 3 columns
# (i.e. 4 rows) and display it.
array = make_array()
result = gallery(array, ncols=3)
plt.imshow(result)
plt.show()

This yields a grid of 4 rows × 3 columns showing the face image (the original screenshot is not reproduced here).


We have an array of shape (nrows*ncols, height, width, intensity). We want an array of shape (height*nrows, width*ncols, intensity).

So the idea here is to first use reshape to split apart the first axis into two axes, one of length nrows and one of length ncols:

array.reshape(nrows, ncols, height, width, intensity)

This allows us to use swapaxes(1,2) to reorder the axes so that the shape becomes (nrows, height, ncols, width, intensity). Notice that this places nrows next to height and ncols next to width.

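Since reshape does not change the raveled (C) order of the data, reshape(height*nrows, width*ncols, intensity) now produces the desired array. Merging (nrows, height) stacks the image rows vertically, and merging (ncols, width) places the images of one grid row side by side.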

This is, in spirit, the same idea as the one used in the `unblockshaped` function.


Another way is to use `skimage.util.view_as_blocks`, which avoids having to swap axes by hand:

from skimage.util import view_as_blocks
import numpy as np

def refactor(im_in, ncols=3):
    """Tile a stack of images into one grid image using ``view_as_blocks``.

    Parameters
    ----------
    im_in : numpy.ndarray
        Image stack of shape ``(n, h, w, c)``.
    ncols : int, optional
        Number of images per grid row. Must be positive.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nrows*h, ncols*w, c)`` with the dtype of ``im_in``,
        where ``nrows = ceil(n / ncols)``. Image ``k`` occupies grid cell
        ``(k // ncols, k % ncols)``; cells past the last image are zero.

    Raises
    ------
    ValueError
        If ``ncols`` is not positive.
    """
    n, h, w, c = im_in.shape
    if ncols < 1:
        raise ValueError("ncols must be a positive integer, got {!r}".format(ncols))
    dn = (-n) % ncols  # blank cells needed to complete the last grid row
    nrows = (n + dn) // ncols
    im_out = np.empty((nrows * h, ncols * w, c), im_in.dtype)
    # view has shape (nrows, ncols, 1, h, w, c); writing to view[i, j, 0]
    # writes grid cell (i, j) of im_out in place.
    view = view_as_blocks(im_out, (h, w, c))
    nfull = n // ncols  # grid rows completely filled with images
    # Fill all complete rows in one vectorized assignment.
    view[:nfull, :, 0] = im_in[:nfull * ncols].reshape(nfull, ncols, h, w, c)
    if dn:
        # Last, partial row: the remaining images, then zero-filled cells.
        view[nfull, :ncols - dn, 0] = im_in[nfull * ncols:]
        view[nfull, ncols - dn:, 0] = 0
    return im_out