Efficiently compute pairwise squared Euclidean distance in Matlab

For squared Euclidean distance one can also use the following formula

||a-b||^2 = ||a||^2 + ||b||^2 - 2<a,b>

where <a,b> is the dot product of a and b. In Matlab, with one point per row of A and B:

nA = sum( A.^2, 2 ); %// squared norms of A's rows
nB = sum( B.^2, 2 ); %// squared norms of B's rows
distMat = bsxfun( @plus, nA, nB' ) - 2 * A * B' ;
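
As a quick sanity check, here is a minimal sketch that compares the formula against an explicit double loop. The sizes and the names ref and maxErr are only for illustration. Because the expansion subtracts numbers of similar size, rounding can leave tiny negative entries where the true distance is zero, so the sketch clamps at zero (my addition, not part of the formula itself). From R2016b on, implicit expansion lets you write nA + nB' instead of the bsxfun call.

```matlab
% Illustrative sizes only; one point per row
A = rand(5, 3);
B = rand(4, 3);

nA = sum(A.^2, 2);                          % squared norms of A's rows (5x1)
nB = sum(B.^2, 2);                          % squared norms of B's rows (4x1)
distMat = bsxfun(@plus, nA, nB') - 2 * (A * B');
distMat = max(distMat, 0);                  % clamp tiny negatives caused by rounding

% Reference: explicit double loop over all pairs
ref = zeros(size(A, 1), size(B, 1));
for i = 1:size(A, 1)
    for j = 1:size(B, 1)
        ref(i, j) = sum((A(i, :) - B(j, :)).^2);
    end
end
maxErr = max(abs(distMat(:) - ref(:)));     % should be at rounding level
```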

Recently, I've been told that as of R2016b this method for computing the squared Euclidean distance is faster than the accepted method.


The answer usually given here is based on bsxfun (cf. e.g. [1]). My proposed approach is based on matrix multiplication and turns out to be much faster than any comparable algorithm I could find. Here A is numA-by-d and B is numB-by-d, with one point per row:

helpA = zeros(numA,3*d);
helpB = zeros(numB,3*d);
for idx = 1:d
    helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*A(:,idx), A(:,idx).^2 ];
    helpB(:,3*idx-2:3*idx) = [B(:,idx).^2 ,    B(:,idx), ones(numB,1)];
end
distMat = helpA * helpB';
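
Why this works: for each coordinate k, the three helper columns pair up as 1*b_k^2 + (-2*a_k)*b_k + a_k^2*1 = (a_k - b_k)^2, and the matrix product adds these terms over all coordinates. A minimal sketch that checks one entry by hand (small sizes and the names direct and err are chosen only for illustration):

```matlab
d = 3; numA = 6; numB = 4;                  % illustrative sizes
A = rand(numA, d);
B = rand(numB, d);

helpA = zeros(numA, 3*d);
helpB = zeros(numB, 3*d);
for idx = 1:d
    helpA(:, 3*idx-2:3*idx) = [ones(numA,1), -2*A(:,idx), A(:,idx).^2];
    helpB(:, 3*idx-2:3*idx) = [B(:,idx).^2,  B(:,idx),    ones(numB,1)];
end
distMat = helpA * helpB';

% Entry (2,3) is the squared distance between A(2,:) and B(3,:)
direct = sum((A(2,:) - B(3,:)).^2);
err = abs(distMat(2,3) - direct);           % should be at rounding level
```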

Please note: For constant d one can replace the for-loop with a hardcoded implementation, e.g.

helpA = [ones(numA,1), -2*A(:,1), A(:,1).^2, ... % d == 2
         ones(numA,1), -2*A(:,2), A(:,2).^2 ];   % etc.
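
A loop-free variant for general d also seems possible, since the product only needs matching columns of the two helper matrices to line up, not any particular column order. The following is a sketch of that idea, not something I have benchmarked; the names loopA, loopB and maxErr are just for the comparison.

```matlab
d = 4; numA = 7; numB = 5;                  % illustrative sizes
A = rand(numA, d);
B = rand(numB, d);

% Column blocks instead of interleaved triples:
% [1 ... 1 | -2*a_1 ... -2*a_d | a_1^2 ... a_d^2] against
% [b_1^2 ... b_d^2 | b_1 ... b_d | 1 ... 1]
helpA = [ones(numA, d), -2*A, A.^2];
helpB = [B.^2,          B,    ones(numB, d)];
distMat = helpA * helpB';

% Compare with the interleaved loop construction
loopA = zeros(numA, 3*d);
loopB = zeros(numB, 3*d);
for idx = 1:d
    loopA(:, 3*idx-2:3*idx) = [ones(numA,1), -2*A(:,idx), A(:,idx).^2];
    loopB(:, 3*idx-2:3*idx) = [B(:,idx).^2,  B(:,idx),    ones(numB,1)];
end
maxErr = max(max(abs(distMat - loopA * loopB')));
```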

Evaluation:

%% create some points
d = 2; % dimension
numA = 20000;
numB = 20000;
A = rand(numA,d);
B = rand(numB,d);

%% pairwise distance matrix
% proposed method:
tic;
helpA = zeros(numA,3*d);
helpB = zeros(numB,3*d);
for idx = 1:d
    helpA(:,3*idx-2:3*idx) = [ones(numA,1), -2*A(:,idx), A(:,idx).^2 ];
    helpB(:,3*idx-2:3*idx) = [B(:,idx).^2 ,    B(:,idx), ones(numB,1)];
end
distMat = helpA * helpB';
toc;

% compare to pdist2:
tic;
pdist2(A,B).^2;
toc;

% compare to [1]:
tic;
bsxfun(@plus,dot(A,A,2),dot(B,B,2)')-2*(A*B');
toc;

% Another method: added 07/2014
% compare to ndgrid method (cf. Dan's comment)
tic;
[idxA,idxB] = ndgrid(1:numA,1:numB);
distMat = zeros(numA,numB);
distMat(:) = sum((A(idxA,:) - B(idxB,:)).^2,2);
toc;

Result:

Elapsed time is 1.796201 seconds.
Elapsed time is 5.653246 seconds.
Elapsed time is 3.551636 seconds.
Elapsed time is 22.461185 seconds.

For a more detailed evaluation w.r.t. dimension and number of data points, follow the discussion below (@comments). It turns out that different algorithms should be preferred in different settings. In situations that are not time-critical, just use the pdist2 version.
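
Single tic/toc runs like the ones above can be noisy. If you want to repeat the comparison for your own sizes, timeit (which calls the function several times and returns a typical run time) is one option. This is only a sketch: the sizes are smaller than in my evaluation, pdist2 is left out so that it runs without the Statistics Toolbox, and the helper matrices are built by concatenating column blocks instead of the loop, which yields the same product.

```matlab
d = 2; numA = 2000; numB = 2000;            % smaller than above to keep memory modest
A = rand(numA, d);
B = rand(numB, d);

% bsxfun-based expansion
fBsx = @() bsxfun(@plus, sum(A.^2, 2), sum(B.^2, 2)') - 2 * (A * B');

% matrix-product approach, helper matrices built without the loop
fMul = @() [ones(numA, d), -2*A, A.^2] * [B.^2, B, ones(numB, d)]';

tBsx = timeit(fBsx);                        % typical run time in seconds
tMul = timeit(fMul);
fprintf('bsxfun: %.4f s, matrix product: %.4f s\n', tBsx, tMul);
```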

Further development: Based on the same principle, one can think of replacing the squared Euclidean distance with any other metric that is a sum over the coordinates of some function of the coordinate-wise differences:

help = zeros(numA,numB,d);
for idx = 1:d
    help(:,:,idx) = [ones(numA,1), A(:,idx)     ] * ...
                    [B(:,idx)'   ; -ones(1,numB)];
end
distMat = sum(ANYFUNCTION(help),3);
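
As one concrete instance, taking abs for ANYFUNCTION should give the Manhattan (L1) distance. A minimal sketch, with the 3-dimensional array named diffs here (my choice, to avoid shadowing Matlab's built-in help function) and a double loop as reference:

```matlab
d = 3; numA = 5; numB = 4;                  % illustrative sizes
A = rand(numA, d);
B = rand(numB, d);

diffs = zeros(numA, numB, d);               % diffs(i,j,k) = B(j,k) - A(i,k)
for idx = 1:d
    diffs(:,:,idx) = [ones(numA,1), A(:,idx)] * ...
                     [B(:,idx)'   ; -ones(1,numB)];
end
distL1 = sum(abs(diffs), 3);                % ANYFUNCTION = abs

% Reference: explicit double loop
ref = zeros(numA, numB);
for i = 1:numA
    for j = 1:numB
        ref(i, j) = sum(abs(A(i,:) - B(j,:)));
    end
end
maxErr = max(abs(distL1(:) - ref(:)));
```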

Nevertheless, this is quite time consuming. For smaller d it could be useful to replace the 3-dimensional matrix help with d 2-dimensional matrices. For d = 1 in particular, it provides a method to compute the pairwise differences B(j) - A(i) by a simple matrix multiplication:

pairDiffs = [ones(numA,1), A ] * [B'; -ones(1,numB)];
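
To make the sign convention explicit, here is a small sketch for d = 1 that compares the product with the bsxfun equivalent (sizes and the names ref and maxErr are illustrative):

```matlab
numA = 5; numB = 3;                         % illustrative sizes, d = 1
A = rand(numA, 1);
B = rand(numB, 1);

% Entry (i,j) is 1*B(j) + A(i)*(-1) = B(j) - A(i)
pairDiffs = [ones(numA,1), A] * [B'; -ones(1,numB)];

ref = bsxfun(@minus, B', A);                % numA-by-numB, B(j) - A(i)
maxErr = max(abs(pairDiffs(:) - ref(:)));
```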

Do you have any further ideas?