Numpy: find row-wise common element efficiently

Approach #1

Here's a vectorized one based on a searchsorted2d helper, i.e. a row-wise searchsorted (a possible implementation is sketched after the code) -

import numpy as np

# Sort each row of a and b in-place
a.sort(1)
b.sort(1)

# Use 2D searchsorted row-wise between a and b
idx = searchsorted2d(a,b)

# "Clip-out" out of bounds indices
idx[idx==a.shape[1]] = 0

# Get mask of valid ones i.e. matches
mask = np.take_along_axis(a,idx,axis=1)==b

# Use argmax to get first match as we know there's at most one match
match_val = np.take_along_axis(b,mask.argmax(1)[:,None],axis=1)

# Finally use np.where to choose between valid match 
# (decided by any one True in each row of mask)
out = np.where(mask.any(1)[:,None],match_val,np.nan)
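searchsorted2d is not a NumPy built-in but a helper that performs np.searchsorted row by row. One way such a helper can be written (a sketch only; the exact implementation may differ, and it assumes each row of a is already sorted, which holds after the in-place sort above) is to shift every row by a per-row offset so a single flat searchsorted call covers all rows:

def searchsorted2d(a, b):
    # Shift each row by a constant larger than the combined value range so
    # rows cannot interleave; one flat np.searchsorted then handles them all
    m, n = a.shape
    offset = (max(a.max(), b.max()) - min(a.min(), b.min()) + 1) * np.arange(m)[:, None]
    p = np.searchsorted((a + offset).ravel(), (b + offset).ravel()).reshape(m, -1)
    return p - n * np.arange(m)[:, None]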

Approach #2

Numba-based one for memory efficiency -

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def numba_f1(a,b,out):
    n,a_ncols = a.shape
    b_ncols = b.shape[1]
    for i in prange(n):
        for j in range(a_ncols):
            for k in range(b_ncols):
                m = a[i,j]==b[i,k]
                if m:
                    break
            if m:
                out[i] = a[i,j]
                break
    return out

def find_first_common_elem_per_row(a,b):
    out = np.full(len(a),np.nan)
    numba_f1(a,b,out)
    return out

Approach #3

Here's another vectorized one based on stacking and sorting -

# Stack a and b side by side and sort each row, keeping the argsort indices
# so we can tell whether each sorted element came from a or from b
r = np.arange(len(a))
ab = np.hstack((a,b))
idx = ab.argsort(1)
ab_s = ab[r[:,None],idx]

# Equal neighbours in the sorted rows are candidate matches
m = ab_s[:,:-1] == ab_s[:,1:]

# Keep only pairs whose second element originates from b (source column >= a.shape[1])
m2 = (idx[:,1:]*m)>=a.shape[1]
m3 = m & m2

# Map matched indices back into b; NaN for rows without a match
out = np.where(m3.any(1),b[r,idx[r,m3.argmax(1)+1]-a.shape[1]],np.nan)

Approach #4

For an elegant (though resource-hungry) one, we can make use of broadcasting -

m = (a[:,None]==b[:,:,None]).any(2)
out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)
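As a quick sanity check, here's a tiny made-up input (values chosen purely for illustration) run through this broadcasting version; the other approaches are meant to compute the same result:

a = np.array([[1., 4., 7.],
              [2., 5., 8.]])
b = np.array([[9., 4.],
              [3., 6.]])

m = (a[:,None]==b[:,:,None]).any(2)
out = np.where(m.any(1),b[np.arange(len(a)),m.argmax(1)],np.nan)
# row 0 shares 4.0 with b, row 1 shares nothing:
# out -> array([ 4., nan])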

Doing some research, I found that checking whether two lists are disjoint runs in O(n+m), where n and m are the lengths of the lists (see here). The idea is that insertion and lookup of elements run in constant time for hash maps. Therefore, inserting all elements of the first list into a hash map takes O(n) operations, and checking for each element of the second list whether it is already in the hash map takes O(m) operations. Solutions based on sorting, which run in O(n log(n) + m log(m)), are therefore not asymptotically optimal.
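For a single pair of rows, the idea looks like this in plain Python (an illustrative sketch only; first_common is a made-up name and is not used in the benchmarks below):

def first_common(row_a, row_b):
    # Build a hash set from one row: O(n) insertions, amortized O(1) each
    seen = set(row_a)
    # Probe it with the other row: O(m) lookups, amortized O(1) each
    for x in row_b:
        if x in seen:
            return x
    return float('nan')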

Though the solutions by @Divakar are highly efficient in many use cases, they become less efficient if the second dimension is large. In that case, a solution based on hash maps is better suited. I have implemented it in Cython as follows:

# distutils: language = c++
import numpy as np
cimport numpy as np
import cython
from libc.math cimport NAN
from libcpp.unordered_map cimport unordered_map
np.import_array()

@cython.boundscheck(False)
@cython.wraparound(False)
def get_common_element2d(np.ndarray[double, ndim=2] arr1, 
                         np.ndarray[double, ndim=2] arr2):

    cdef np.ndarray[double, ndim=1] result = np.empty(arr1.shape[0])
    cdef int dim1 = arr1.shape[1]
    cdef int dim2 = arr2.shape[1]
    cdef int i, j
    cdef unordered_map[double, int] tmpset = unordered_map[double, int]()

    for i in range(arr1.shape[0]):
        for j in range(dim1):
            # insert arr1[i, j] as key without assigned value
            tmpset[arr1[i, j]]
        for j in range(dim2):
            # check whether arr2[i, j] is in tmpset
            if tmpset.count(arr2[i,j]):
                result[i] = arr2[i,j]
                break
        else:
            result[i] = NAN
        tmpset.clear()

    return result
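Because the module uses libcpp.unordered_map, it has to be compiled as C++. A minimal build script sketch (the setup.py itself is my addition; only the module name mycythonmodule is taken from the test code below):

# setup.py
from setuptools import setup, Extension
from Cython.Build import cythonize
import numpy as np

setup(
    ext_modules=cythonize([
        Extension(
            "mycythonmodule",
            ["mycythonmodule.pyx"],
            language="c++",                   # needed for libcpp.unordered_map
            include_dirs=[np.get_include()],  # for "cimport numpy"
        )
    ]),
)

Building with python setup.py build_ext --inplace then makes the module importable as shown below.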

I have created test cases as follows:

import numpy as np
import timeit
from itertools import starmap
from mycythonmodule import get_common_element2d

m, n = 3000, 3000
a = np.random.rand(m, n)
b = np.random.rand(m, n)

for i, row in enumerate(a):
    if np.random.randint(2):
        common = np.random.choice(row, 1)
        b[i][np.random.choice(np.arange(n), np.random.randint(min(n,20)), False)] = common

# we need to copy the arrays on each test run, otherwise they 
# will remain sorted, which would bias the results

%timeit [set(aa).intersection(bb) for aa, bb in zip(a.copy(), b.copy())]
# returns 3.11 s ± 56.8 ms

%timeit list(starmap(np.intersect1d, zip(a.copy(), b.copy())))
# returns 1.83 s ± 55.4

# test sorting method
# divakarsMethod1 is the approach #1 in @Divakar's answer
%timeit divakarsMethod1(a.copy(), b.copy())
# returns 1.88 s ± 18 ms

# test hash map method
%timeit get_common_element2d(a.copy(), b.copy())
# returns 1.46 s ± 22.6 ms

These results seem to indicate that the naive approach is actually better than some vectorized versions. However, the vectorized algorithms play to their strengths when many rows with fewer columns are considered (a different use case). In these cases, the vectorized approaches are more than 5 times faster than the naive approach, and the sorting method turns out to be best.

Conclusion: I will go with the hash-map-based Cython version, because it is among the most efficient variants in both use cases. If I did not already have Cython set up, I would use the sorting-based method instead.