Fast alternative for numpy.median.reduceat

Here's a NumPy-based approach to get the binned median for non-negative integer bin/index values -

def bin_median(a, i):
    sidx = np.lexsort((a,i))

    a = a[sidx]
    i = i[sidx]

    c = np.bincount(i)
    c = c[c!=0]

    s1 = c//2

    e = c.cumsum()
    s1[1:] += e[:-1]

    firstval = a[s1-1]
    secondval = a[s1]
    out = np.where(c%2,secondval,(firstval+secondval)/2.0)
    return out
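
To see why the index arithmetic works, here is a small worked sketch of the same steps on toy data. The input values and the names offsets, upper, lower and expected are illustrative only; the result is cross-checked against a plain per-group np.median.

```python
import numpy as np

# Toy input (illustrative): values `a` with non-negative integer labels `i`
a = np.array([1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67])
i = np.array([0, 0, 1, 1, 1, 1, 2, 3, 3])

sidx = np.lexsort((a, i))   # sort by label first, then by value within a label
a_sorted = a[sidx]

c = np.bincount(i)          # elements per label
c = c[c != 0]               # drop labels that never occur

# First sorted position of each group, then the upper-middle element of each group
offsets = np.concatenate(([0], c.cumsum()[:-1]))
s1 = offsets + c // 2

upper = a_sorted[s1]
lower = a_sorted[s1 - 1]    # only meaningful for even-sized groups
medians = np.where(c % 2, upper, (lower + upper) / 2.0)

# Cross-check against a straightforward per-group median
expected = np.array([np.median(a[i == g]) for g in np.unique(i)])
print(medians)
print(np.allclose(medians, expected))
```

For odd-sized groups the median is the single middle element; for even-sized groups it is the mean of the two middle elements, which is what the np.where selects.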

To solve our specific case, where each group's median is subtracted from that group's elements -

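def bin_median_subtract(a, i):
    # sort order by bin first, then by value within each bin
    sidx = np.lexsort((a,i))

    # number of elements per bin, dropping empty bins
    c = np.bincount(i)

    valid_mask = c!=0
    c = c[valid_mask]

    # s1: position (in sorted order) of the upper-middle element of each bin
    e = c.cumsum()
    s1 = c//2
    s1[1:] += e[:-1]
    ssidx = sidx.argsort()
    # for odd-sized bins starts == ends (the single middle element)
    starts = c%2+s1-1
    ends = s1

    # map sorted positions back to positions in the original array
    starts_orgindx = sidx[np.searchsorted(sidx,starts,sorter=ssidx)]
    ends_orgindx  = sidx[np.searchsorted(sidx,ends,sorter=ssidx)]
    val = (a[starts_orgindx] + a[ends_orgindx])/2.
    # np.repeat(val,c) lines up with a only if i is already sorted
    out = a-np.repeat(val,c)
    return out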
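
When checking a vectorized version like this, it helps to have a slow but obvious reference to compare against. The following is a minimal sketch of such a reference; naive_median_subtract is a hypothetical helper name, not part of the answer above. Unlike the np.repeat step above, it does not need the labels to be sorted.

```python
import numpy as np

def naive_median_subtract(a, i):
    """Hypothetical reference: subtract each group's median, one group at a time."""
    a = np.asarray(a, dtype=float)
    i = np.asarray(i)
    out = np.empty_like(a)
    for g in np.unique(i):
        mask = i == g                      # elements belonging to label g
        out[mask] = a[mask] - np.median(a[mask])
    return out

# Toy data (illustrative)
a = np.array([1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67])
i = np.array([0, 0, 1, 1, 1, 1, 2, 3, 3])

result = naive_median_subtract(a, i)
print(result)
```

It does one pass over the data per label, so it is only meant for validating results on small inputs.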

Sometimes you need to write non-idiomatic numpy code to speed up a calculation that native numpy can't make fast.

numba compiles your python code to machine code (via LLVM). Since a lot of numpy itself is usually about as fast as C, numba mostly ends up being useful when your problem doesn't lend itself to native vectorization with numpy. This is one such example (where I assumed that the indices are contiguous and sorted, which is also reflected in the example data):

import numpy as np
import numba

# use the inflated example of roganjosh https://stackoverflow.com/a/58788534
data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3] 

data = np.array(data * 500) # using arrays is important for numba!
index = np.sort(np.random.randint(0, 30, 4500))               

# jit-decorate; original is available as .py_func attribute
@numba.njit('f8[:](f8[:], i8[:])') # explicit signature triggers eager compilation at definition
def diffmedian_jit(data, index): 
    res = np.empty_like(data) 
    i_start = 0 
    for i in range(1, index.size): 
        if index[i] == index[i_start]: 
            continue 

        # here: i is the first _next_ index 
        inds = slice(i_start, i)  # i_start:i slice 
        res[inds] = data[inds] - np.median(data[inds]) 

        i_start = i 

    # also fix last label 
    res[i_start:] = data[i_start:] - np.median(data[i_start:])

    return res
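
The loop above finds group boundaries by scanning for the position where the label changes. For sorted labels the same boundaries can also be computed up front with np.diff. Here is a minimal pure-numpy sketch of that idea (not the jitted function itself; starts and ends are illustrative names):

```python
import numpy as np

# Toy data (illustrative); the labels must be sorted, as the loop also assumes
data = np.array([1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67])
index = np.array([0, 0, 1, 1, 1, 1, 2, 3, 3])

# A label change between positions k and k+1 means a new group starts at k+1
starts = np.concatenate(([0], np.flatnonzero(np.diff(index)) + 1))
ends = np.append(starts[1:], index.size)

res = np.empty_like(data)
for s, e in zip(starts, ends):
    res[s:e] = data[s:e] - np.median(data[s:e])

print(starts, ends)
print(res)
```

This still loops once per group in Python, so it illustrates the boundary logic rather than replacing the compiled version.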

And here are some timings using IPython's %timeit magic:

>>> %timeit diffmedian_jit.py_func(data, index)  # non-jitted function
... %timeit diffmedian_jit(data, index)  # jitted function
...
4.27 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
65.2 µs ± 1.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Using the updated example data in the question, these numbers (i.e. the runtime of the python function vs. the runtime of the JIT-accelerated function) are

>>> %timeit diffmedian_jit.py_func(data, groups) 
... %timeit diffmedian_jit(data, groups)
2.45 s ± 34.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
93.6 ms ± 518 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

This amounts to a 65x speedup in the smaller case and a 26x speedup in the larger case (compared to slow loopy code, of course) using the accelerated code. Another upside is that (unlike typical vectorization with native numpy) we didn't need additional memory to achieve this speed; it's all about the optimized and compiled low-level code that ends up being run.


The above function assumes that numpy int arrays are int64 by default, which is not actually the case on Windows (with NumPy versions before 2.0 the default there is int32). So an alternative is to remove the signature from the call to numba.njit, triggering proper just-in-time compilation. But this means that the function will be compiled during the first execution, which can distort timing results (we can either execute the function once manually, using representative data types, or just accept that the first timed execution will be much slower and should be ignored). This is exactly what I tried to prevent by specifying a signature, which triggers eager compilation at definition time.
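
Another option, if you would rather keep the explicit signature, is to cast the index array to int64 before the call. A minimal sketch of that cast (index64 is an illustrative name; the printed dtype depends on your platform and NumPy version):

```python
import numpy as np

index = np.sort(np.random.randint(0, 30, 4500))
print(index.dtype)  # platform dependent, e.g. int32 or int64

# Cast once so the array matches an 'i8[:]' argument in an explicit signature.
# copy=False avoids a copy when the dtype is already int64.
index64 = index.astype(np.int64, copy=False)
print(index64.dtype)
```

The cast costs one extra copy on platforms where the default is not int64, which is usually negligible next to the median computation.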

Anyway, in the proper JIT case the decorator we need is just

@numba.njit
def diffmedian_jit(...):

Note that the timings I showed above for the jit-compiled function only apply once the function has been compiled. This happens either at definition (with eager compilation, when an explicit signature is passed to numba.njit), or during the first function call (with lazy compilation, when no signature is passed to numba.njit). If the function is only going to be executed once, then the compile time should also be counted towards the runtime of this method. It's typically only worth compiling functions if the total time of compilation + execution is less than the uncompiled runtime (which is actually true in the above case, where the native python function is very slow). This mostly happens when you are calling your compiled function many times.
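
As a rough back-of-the-envelope check of that trade-off, the break-even number of calls can be estimated from the per-call timings. In the sketch below the per-call times are the ones measured above, while the compile time of 0.5 s is purely an assumption for illustration (measure it on your own machine); break_even_calls is a hypothetical helper name.

```python
import math

def break_even_calls(t_py, t_jit, t_compile):
    """Smallest number of calls for which compile time + jitted runs beat plain python."""
    saved_per_call = t_py - t_jit
    return max(1, math.ceil(t_compile / saved_per_call))

t_compile = 0.5  # ASSUMED compile time in seconds, not a measured value

# Per-call timings in seconds, taken from the %timeit results above
small = break_even_calls(t_py=4.27e-3, t_jit=65.2e-6, t_compile=t_compile)
large = break_even_calls(t_py=2.45, t_jit=93.6e-3, t_compile=t_compile)

print(small, large)
```

Under this assumed compile time the small example pays off only after on the order of a hundred calls, while the large example pays off on the first call.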

As max9111 noted in a comment, one important feature of numba is the cache keyword to jit. Passing cache=True to numba.jit will store the compiled function to disk, so that during the next execution of the given python module the function will be loaded from there rather than recompiled, which again can spare you runtime in the long run.


Maybe you already tried this, but if not, see if it's fast enough:

median_dict = {i: np.median(data[index == i]) for i in np.unique(index)}
def myFunc(my_dict, a): 
    return my_dict[a]
vect_func = np.vectorize(myFunc)
median_diff = data - vect_func(median_dict, index)
median_diff

Output:

array([-0.025,  0.025,  0.05 , -0.05 , -0.19 ,  0.29 ,  0.   ,  0.1  ,
       -0.1  ])
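
The dictionary lookup through np.vectorize can also be replaced by plain array indexing. The following is a minimal sketch of that variant, assuming nothing beyond numpy; it uses np.unique with return_inverse=True, and the names labels, inverse and medians are illustrative.

```python
import numpy as np

data = np.array([1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67])
index = np.array([0, 0, 1, 1, 1, 1, 2, 3, 3])

# labels: the distinct group labels; inverse: for each element, the position
# of its label inside `labels`
labels, inverse = np.unique(index, return_inverse=True)

# One median per label, stored in an array instead of a dict
medians = np.array([np.median(data[index == lab]) for lab in labels])

# Fancy indexing broadcasts each group's median back to its elements
median_diff = data - medians[inverse]
print(median_diff)
```

The per-label medians are still computed in a Python loop, so this only removes the per-element lookup overhead, not the cost of the masks.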

One approach would be to use Pandas here purely to make use of groupby. I've inflated the input sizes a bit to give a better understanding of the timings (since there is overhead in creating the DF).

import numpy as np
import pandas as pd

data =  [1.00, 1.05, 1.30, 1.20, 1.06, 1.54, 1.33, 1.87, 1.67]
index = [0,    0,    1,    1,    1,    1,    2,    3,    3]

data = data * 500
index = np.sort(np.random.randint(0, 30, 4500))

def df_approach(data, index):
    df = pd.DataFrame({'data': data, 'label': index})
    df['median'] = df.groupby('label')['data'].transform('median')
    df['result'] = df['data'] - df['median']
    return df['result'].values  # hand the differences back to the caller

Gives the following timeit:

%timeit df_approach(data, index)
5.38 ms ± 50.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

For the same sample size, I time the dict approach of Aryerez at:

%timeit dict_approach(data, index)
8.12 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

However, if we increase the inputs by another factor of 10, the timings become:

%timeit df_approach(data, index)
7.72 ms ± 85 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

%timeit dict_approach(data, index)
30.2 ms ± 10.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

However, at the expense of some readability, the answer by Divakar using pure numpy comes in at:

%timeit bin_median_subtract(data, index)
573 µs ± 7.48 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In light of the new dataset (which really should have been set at the start):

%timeit df_approach(data, groups)
472 ms ± 2.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit bin_median_subtract(data, groups) #https://stackoverflow.com/a/58788623/4799172
3.02 s ± 31.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit dict_approach(data, groups) #https://stackoverflow.com/a/58788199/4799172
<I gave up after 1 minute>

# jitted (using @numba.njit('f8[:](f8[:], i4[:])') on Windows) from https://stackoverflow.com/a/58788635/4799172
%timeit diffmedian_jit(data, groups)
132 ms ± 3.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)