Indexing the max elements in a multidimensional tensor in PyTorch

An ugly workaround is to build a boolean mask from idx and use it to index the tensors. The basic code looks like this:

import torch
torch.manual_seed(0)

A = torch.randn((5, 2, 3))
_, idx = torch.max(A, dim=2)

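# Broadcast (1, 1, 3) against (5, 2, 1): mask[i, j, k] is True where k == idx[i, j]
mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)
B = torch.zeros_like(A)
# Copy only the max elements; everything else stays zero
B[mask] = A[mask]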
print(A)
print(B)

The trick is that torch.arange(A.size(2)) enumerates the possible values in idx, and mask is True wherever that enumeration equals idx. Remarks:

  1. If you really discard the first output of torch.max, you can use torch.argmax instead.
  2. I assume that this is a minimal example of some wider problem, but be aware that you are currently reinventing torch.nn.functional.max_pool3d with kernel of size (1, 1, 3).
  3. Also, be aware that in-place modification of tensors with masked assignment can cause issues with autograd, so you may want to use torch.where instead.
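Remarks 1 and 3 can be combined into an out-of-place variant. The following is only a minimal sketch under the same setup as above; setting requires_grad=True on A is an assumption added here to show that gradients flow through torch.where.

```python
import torch

torch.manual_seed(0)

# requires_grad is an assumption added for illustration; the original A has none.
A = torch.randn((5, 2, 3), requires_grad=True)

# argmax returns only the indices, so no output is discarded.
idx = torch.argmax(A, dim=2)

# Same broadcasted comparison as before: True where the last index equals idx.
mask = torch.arange(A.size(2)).reshape(1, 1, -1) == idx.unsqueeze(2)

# Out-of-place selection: take A where mask is True, zero elsewhere.
B = torch.where(mask, A, torch.zeros_like(A))

# The gradient of B.sum() w.r.t. A is 1 at the max positions and 0 elsewhere.
B.sum().backward()
print(A.grad)
```

Because B is built without writing into an existing tensor, no in-place operation is involved in the graph.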

I would expect somebody to come up with a cleaner solution (avoiding the intermediate allocation of the mask array), likely making use of torch.index_select, but I can't get it to work right now.
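One possible mask-free variant, sketched here with torch.gather and Tensor.scatter rather than torch.index_select (a swapped-in technique, not the index_select solution hoped for above), might look like this:

```python
import torch

torch.manual_seed(0)
A = torch.randn((5, 2, 3))

# keepdim=True keeps the reduced axis, giving idx the shape (5, 2, 1).
idx = torch.argmax(A, dim=2, keepdim=True)

# gather picks A[i, j, idx[i, j, 0]] for every (i, j).
max_vals = torch.gather(A, 2, idx)

# scatter writes those values back to the same positions of a zero tensor.
B = torch.zeros_like(A).scatter(2, idx, max_vals)
print(B)
```

This still allocates the small max_vals tensor, but it does not build a full-size boolean mask.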


You can use torch.meshgrid to create an index tuple:

>>> index_tuple = torch.meshgrid([torch.arange(x) for x in A.size()[:-1]]) + (idx,)
>>> B = torch.zeros_like(A)
>>> B[index_tuple] = A[index_tuple]
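For reference, a self-contained version of the snippet above could look as follows. It is a sketch that assumes a PyTorch release where torch.meshgrid accepts the indexing keyword; passing indexing="ij" matches the default ordering used above and avoids the warning newer releases emit when it is omitted.

```python
import torch

torch.manual_seed(0)
A = torch.randn((5, 2, 3))
idx = torch.argmax(A, dim=2)  # shape (5, 2)

# One index grid per leading dimension, each of shape (5, 2).
grids = torch.meshgrid(*[torch.arange(n) for n in A.shape[:-1]], indexing="ij")

# Append idx as the index for the last dimension.
index_tuple = tuple(grids) + (idx,)

B = torch.zeros_like(A)
B[index_tuple] = A[index_tuple]  # A[index_tuple] has shape (5, 2)
print(B)
```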

For the specific 3D case, you can also mimic meshgrid by relying on broadcasting:

>>> index_tuple = (
...     torch.arange(A.size(0))[:, None],
...     torch.arange(A.size(1))[None, :],
...     idx
... )
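A quick way to convince yourself that the broadcast version selects the same elements is to compare it with torch.max directly. This is a minimal check, reusing the same A and idx as above:

```python
import torch

torch.manual_seed(0)
A = torch.randn((5, 2, 3))
idx = torch.argmax(A, dim=2)

index_tuple = (
    torch.arange(A.size(0))[:, None],  # shape (5, 1)
    torch.arange(A.size(1))[None, :],  # shape (1, 2)
    idx,                               # shape (5, 2)
)

# The three index tensors broadcast to (5, 2), so one element is picked per (i, j).
picked = A[index_tuple]
print(torch.equal(picked, A.max(dim=2).values))
```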

A bit more explanation:
The indices will look something like this:

In [173]: idx 
Out[173]: 
tensor([[2, 1],
        [2, 0],
        [2, 1],
        [2, 2],
        [2, 2]])

From this, we want to go to three indices (since our tensor is 3D, we need three numbers to retrieve each element). Basically we want to build a grid in the first two dimensions, as shown below. (And that's why we use meshgrid).

In [174]: A[0, 0, 2], A[0, 1, 1]  
Out[174]: (tensor(0.6288), tensor(-0.3070))

In [175]: A[1, 0, 2], A[1, 1, 0]  
Out[175]: (tensor(1.7085), tensor(0.7818))

In [176]: A[2, 0, 2], A[2, 1, 1]  
Out[176]: (tensor(0.4823), tensor(1.1199))

In [177]: A[3, 0, 2], A[3, 1, 2]    
Out[177]: (tensor(1.6903), tensor(1.0800))

In [178]: A[4, 0, 2], A[4, 1, 2]          
Out[178]: (tensor(0.9138), tensor(0.1779))

In the five lines above, the first two numbers of each index come from the grid that we build using meshgrid, and the third number comes from idx.

That is, the first two numbers form a grid:

 (0, 0) (0, 1)
 (1, 0) (1, 1)
 (2, 0) (2, 1)
 (3, 0) (3, 1)
 (4, 0) (4, 1)
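To see the grid and idx side by side, the full index triples can be printed. This is only an illustrative sketch; the names rows, cols and triples are introduced here and are not part of the answer above, and the third number in each triple depends on the random A.

```python
import torch

torch.manual_seed(0)
A = torch.randn((5, 2, 3))
idx = torch.argmax(A, dim=2)

# rows[i, j] == i and cols[i, j] == j: together they form the grid.
rows, cols = torch.meshgrid(
    torch.arange(A.size(0)), torch.arange(A.size(1)), indexing="ij"
)

# Stack grid coordinates and idx into (i, j, k) triples, shape (5, 2, 3).
triples = torch.stack((rows, cols, idx), dim=-1)
for i in range(A.size(0)):
    print([tuple(t.tolist()) for t in triples[i]])
```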