PyTorch tensor advanced indexing

You can pass the corresponding row index for each column index in y:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
# tensor([1, 6, 8])
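The two index sequences are paired element by element, so the expression picks x[0, y[0]], x[1, y[1]] and x[2, y[2]]. Below is a minimal sketch comparing it with an explicit Python loop. Building the row indices with torch.arange is just an equivalent alternative to range here.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])  # column to pick in each row

# Advanced indexing: row i is paired with column y[i]
fancy = x[torch.arange(x.shape[0]), y]

# Explicit loop doing the same pairing, one element at a time
looped = torch.stack([x[i, y[i]] for i in range(x.shape[0])])

print(fancy)                       # tensor([1, 6, 8])
print(torch.equal(fancy, looped))  # True
```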

Advanced indexing in PyTorch works just like NumPy's, i.e. the index tensors are broadcast together across the axes. So you could do as in FBruzzesi's answer.
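To see the broadcasting at work, consider index tensors of shapes (3, 1) and (2,). They broadcast to (3, 2), so the result has that shape too. Here is a small sketch. The column list [0, 2] is only an illustrative choice.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

rows = torch.arange(3).unsqueeze(1)  # shape (3, 1): one row index per output row
cols = torch.tensor([0, 2])          # shape (2,): columns to take from every row

# rows and cols broadcast to shape (3, 2); out[i, j] = x[rows[i, 0], cols[j]]
out = x[rows, cols]
print(out)
# tensor([[1, 3],
#         [4, 6],
#         [7, 9]])
```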

That said, similarly to np.take_along_axis, PyTorch also has torch.gather, which takes values along a specific axis:

x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])
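For dim=1, torch.gather follows the rule out[i][j] = x[i][index[i][j]], and the index must have the same number of dimensions as x. That is why y is reshaped to a column before the call and flattened afterwards. It also means the same call picks several values per row when the index has more columns. In PyTorch 1.9 and later, torch.take_along_dim is the closer analogue of np.take_along_axis. The sketch below illustrates both. The index k_idx is a made-up example.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])

# One value per row: index shape (3, 1) -> output shape (3, 1), then flatten
one = x.gather(1, y.unsqueeze(1)).squeeze(1)
print(one)  # tensor([1, 6, 8])

# Hypothetical index picking two columns per row: out[i][j] = x[i][k_idx[i][j]]
k_idx = torch.tensor([[2, 0],
                      [1, 1],
                      [0, 2]])
two = x.gather(1, k_idx)
print(two)
# tensor([[3, 1],
#         [5, 5],
#         [7, 9]])

# take_along_dim (PyTorch >= 1.9) mirrors np.take_along_axis
same = torch.take_along_dim(x, k_idx, dim=1)
print(torch.equal(two, same))  # True
```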