Fastest Implementation of the Natural Exponential Function Using SSE

A good increase in accuracy in my algorithm (implementation FastExpSse in the answer above) can be obtained at the cost of an integer subtraction and floating-point division by using FastExpSse(x/2)/FastExpSse(-x/2) instead of FastExpSse(x). The trick here is to set the shift parameter (298765 above) to zero so that the piecewise linear approximations in the numerator and denominator line up to give you substantial error cancellation. Roll it into a single function:

__m128 BetterFastExpSse (__m128 x)
{
  const __m128 a = _mm_set1_ps ((1 << 22) / (float)M_LN2);  // to get exp(x/2)
  const __m128i b = _mm_set1_epi32 (127 * (1 << 23));       // NB: zero shift!
  __m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
  __m128i s = _mm_add_epi32 (b, r);
  __m128i t = _mm_sub_epi32 (b, r);
  return _mm_div_ps (_mm_castsi128_ps (s), _mm_castsi128_ps (t));
}

(I'm not a hardware guy - how bad a performance killer is the division here?)

If you need exp(x) just to get y = tanh(x) (e.g. for neural networks), use FastExpSse with zero shift as follows:

a = FastExpSse(x);
b = FastExpSse(-x);
y = (a - b)/(a + b);

to get the same type of error cancellation benefit. The logistic function works similarly, using FastExpSse(x/2)/(FastExpSse(x/2) + FastExpSse(-x/2)) with zero shift. (This is just to show the principle - you obviously don't want to evaluate FastExpSse multiple times here, but roll it into a single function along the lines of BetterFastExpSse above.)
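As a minimal sketch of what such rolled-up functions might look like (the names FastTanhSse and FastLogisticSse are hypothetical, and I am assuming the same zero-shift constants as in BetterFastExpSse), the numerator and denominator can share a single float-to-int conversion. There is no argument clamping here, so this is only meaningful while the intermediate exponent fields stay in range (roughly |x| < 87):

```c
#include <emmintrin.h>  /* SSE2 intrinsics */

/* Hypothetical helper: tanh(x) ~= (e^x - e^-x) / (e^x + e^-x), zero shift */
static inline __m128 FastTanhSse (__m128 x)
{
  const __m128 a = _mm_set1_ps (12102203.0f);          /* (1 << 23) / log(2) */
  const __m128i b = _mm_set1_epi32 (127 * (1 << 23));  /* NB: zero shift! */
  __m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
  __m128 p = _mm_castsi128_ps (_mm_add_epi32 (b, r));  /* ~= exp(x)  */
  __m128 n = _mm_castsi128_ps (_mm_sub_epi32 (b, r));  /* ~= exp(-x) */
  return _mm_div_ps (_mm_sub_ps (p, n), _mm_add_ps (p, n));
}

/* Hypothetical helper: logistic(x) ~= e^(x/2) / (e^(x/2) + e^(-x/2)), zero shift */
static inline __m128 FastLogisticSse (__m128 x)
{
  const __m128 a = _mm_set1_ps (6051101.5f);           /* (1 << 22) / log(2) */
  const __m128i b = _mm_set1_epi32 (127 * (1 << 23));  /* NB: zero shift! */
  __m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
  __m128 p = _mm_castsi128_ps (_mm_add_epi32 (b, r));  /* ~= exp(x/2)  */
  __m128 n = _mm_castsi128_ps (_mm_sub_epi32 (b, r));  /* ~= exp(-x/2) */
  return _mm_div_ps (p, _mm_add_ps (p, n));
}
```

Because p and n simply swap when x changes sign, the tanh sketch is exactly odd, and both functions hit their exact values at x = 0 (0 and 0.5 respectively).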

I did develop a series of higher-order approximations from this, ever more accurate but also slower. Unpublished but happy to collaborate if anyone wants to give them a spin.

And finally, for some fun: use in reverse gear to get FastLogSse. Chaining that with FastExpSse gives you both operator and error cancellation, and out pops a blazingly fast power function...
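To give a rough idea of the "reverse gear", here is a sketch under my own assumptions (the names FastLogSseSketch and FastPowSseSketch are hypothetical, zero shift is used throughout, and this is the crudest piecewise linear form, not the higher-order variants mentioned above). Reading the bits of a positive normal float as an integer gives approximately 2^23 * (log2(x) + 127); when the log is chained into the exp, the scale factors log(2) / 2^23 and 2^23 / log(2) cancel, leaving a single multiply by the exponent:

```c
#include <emmintrin.h>  /* SSE2 intrinsics */

/* Hypothetical sketch: log(x) for positive normal x, piecewise linear */
static inline __m128 FastLogSseSketch (__m128 x)
{
  const __m128i b = _mm_set1_epi32 (127 * (1 << 23));  /* exponent bias, zero shift */
  const __m128 c = _mm_set1_ps (8.2629583e-8f);        /* log(2) / (1 << 23) */
  __m128i r = _mm_sub_epi32 (_mm_castps_si128 (x), b); /* ~= (1 << 23) * log2(x) */
  return _mm_mul_ps (c, _mm_cvtepi32_ps (r));
}

/* Hypothetical sketch: pow(x, p) = exp(p * log(x)) for positive normal x */
static inline __m128 FastPowSseSketch (__m128 x, __m128 p)
{
  const __m128i b = _mm_set1_epi32 (127 * (1 << 23));  /* exponent bias, zero shift */
  __m128 l = _mm_cvtepi32_ps (_mm_sub_epi32 (_mm_castps_si128 (x), b));
  __m128i r = _mm_cvtps_epi32 (_mm_mul_ps (p, l));     /* scale factors have cancelled */
  return _mm_castsi128_ps (_mm_add_epi32 (b, r));
}
```

This form is exact when x and the result are both powers of two (e.g. 4^0.5 or 2^3), but between those points the error can reach several percent, so treat it as an illustration of the principle only.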


The C code below is a translation into SSE intrinsics of an algorithm I used in a previous answer to a similar question.

The basic idea is to transform the computation of the standard exponential function into computation of a power of 2: expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504). We split t = x * 1.44269504 into an integer i and a fraction f, such that t = i + f and 0 <= f <= 1. We can now compute 2^f with a polynomial approximation, then scale the result by 2^i by adding i to the exponent field of the single-precision floating-point result.
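Before going to intrinsics, it may help to see the same decomposition in scalar form. This is a sketch only: the function name fast_exp_scalar is hypothetical, the coefficients are borrowed from the SSE code in this answer, and no range checking is done, so the argument should stay within about [-87.3, 88.7].

```c
#include <stdint.h>
#include <string.h>

/* Scalar sketch: exp(x) = 2^i * 2^f with t = x * log2(e), i = floor(t), f = t - i */
static float fast_exp_scalar (float x)
{
    const float l2e = 1.442695041f;      /* log2(e) */
    const float c0 = 0.3371894346f;      /* quadratic approximation of 2^f on [0, 1] */
    const float c1 = 0.657636276f;
    const float c2 = 1.00172476f;
    float t = x * l2e;
    int32_t i = (int32_t)t;              /* truncation toward zero */
    if ((float)i > t) i--;               /* turn trunc into floor for negative t */
    float f = t - (float)i;              /* 0 <= f < 1 */
    float p = (c0 * f + c1) * f + c2;    /* p ~= 2^f */
    uint32_t bits;
    memcpy (&bits, &p, sizeof bits);     /* reinterpret p as an integer */
    bits += (uint32_t)i << 23;           /* add i to the exponent field: p * 2^i */
    memcpy (&p, &bits, sizeof p);
    return p;
}
```

For example, fast_exp_scalar (1.0f) splits t = 1.442695 into i = 1 and f = 0.442695, and the result agrees with e to about three decimal digits.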

One problem that exists with an SSE implementation is that we want to compute i = floorf (t), but before SSE4.1 there is no fast way to compute the floor() function. However, we observe that for positive numbers, floor(x) == trunc(x), and that for negative numbers, floor(x) == trunc(x) - 1, except when x is a negative integer. In that exceptional case, subtracting 1 anyway yields i = t - 1 and therefore f = 1.0f, and since the core approximation can handle an f value of 1.0f, using the approximation for all negative arguments is harmless. SSE provides an instruction to convert single-precision floating-point operands to integers with truncation, so this solution is efficient.
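The truncate-then-subtract-the-sign-bit step can be looked at in isolation. A small sketch (the helper name floor_via_trunc is hypothetical; it assumes |t| is small enough to fit in a 32-bit integer):

```c
#include <emmintrin.h>  /* SSE2 intrinsics */

/* Returns floor(t) for non-integer t, but t - 1 for negative integers and -0.0f */
static inline __m128 floor_via_trunc (__m128 t)
{
    __m128i i = _mm_cvttps_epi32 (t);                      /* (int)t, truncating */
    __m128i s = _mm_srli_epi32 (_mm_castps_si128 (t), 31); /* sign bit: 0 or 1 */
    i = _mm_sub_epi32 (i, s);                              /* (int)t - signbit(t) */
    return _mm_cvtepi32_ps (i);
}
```

So 2.7f maps to 2.0f and -1.5f maps to -2.0f as expected, while -2.0f maps to -3.0f, which is the case where the fraction comes out as exactly 1.0f.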

Peter Cordes points out that SSE4.1 supports a fast floor function _mm_floor_ps(), so a variant using SSE4.1 is also shown below. Not all toolchains automatically predefine the macro __SSE4_1__ when SSE 4.1 code generation is enabled, but gcc does.

Compiler Explorer (Godbolt) shows that gcc 7.2 compiles the code below into sixteen instructions for plain SSE and twelve instructions for SSE 4.1.

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif

/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, e, p, r;
    __m128i i, j;
    __m128 l2e = _mm_set1_ps (1.442695041f);  /* log2(e) */
    __m128 c0  = _mm_set1_ps (0.3371894346f);
    __m128 c1  = _mm_set1_ps (0.657636276f);
    __m128 c2  = _mm_set1_ps (1.00172476f);

    /* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */   
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
#ifdef __SSE4_1__
    e = _mm_floor_ps (t);                /* floor(t) */
    i = _mm_cvtps_epi32 (e);             /* (int)floor(t) */
#else /* __SSE4_1__*/
    i = _mm_cvttps_epi32 (t);            /* i = (int)t */
    j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
    i = _mm_sub_epi32 (i, j);            /* (int)t - signbit(t) */
    e = _mm_cvtepi32_ps (i);             /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
    f = _mm_sub_ps (t, e);               /* f = t - floor(t) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}

int main (void)
{
    union {
        float f[4];
        unsigned int i[4];
    } arg, res;
    double relerr, maxrelerr = 0.0;
    int i, j;
    __m128 x, y;

    float start[2] = {-0.0f, 0.0f};
    float finish[2] = {-87.33654f, 88.72283f};

    for (i = 0; i < 2; i++) {

        arg.f[0] = start[i];
        arg.i[1] = arg.i[0] + 1;
        arg.i[2] = arg.i[0] + 2;
        arg.i[3] = arg.i[0] + 3;
        do {
            memcpy (&x, &arg, sizeof(x));
            y = fast_exp_sse (x);
            memcpy (&res, &y, sizeof(y));
            for (j = 0; j < 4; j++) {
                double ref = exp ((double)arg.f[j]);
                relerr = fabs ((res.f[j] - ref) / ref);
                if (relerr > maxrelerr) {
                    printf ("arg=% 15.8e  res=%15.8e  ref=%15.8e  err=%15.8e\n", 
                            arg.f[j], res.f[j], ref, relerr);
                    maxrelerr = relerr;
                }
            }   
            arg.i[0] += 4;
            arg.i[1] += 4;
            arg.i[2] += 4;
            arg.i[3] += 4;
        } while (fabsf (arg.f[3]) < fabsf (finish[i]));
    }
    printf ("maximum relative error = %15.8e\n", maxrelerr);
    return EXIT_SUCCESS;
}

An alternative design for fast_exp_sse() extracts the integer portion of the adjusted argument x / log(2) in round-to-nearest mode, using the well-known technique of adding the "magic" conversion constant 1.5 * 2^23 to force rounding in the correct bit position, then subtracting out the same number again. This requires that the SSE rounding mode in effect during the addition is "round to nearest or even", which is the default. wim pointed out in comments that some compilers may optimize out the addition and subtraction of the conversion constant cvt as redundant when aggressive optimization is used, interfering with the functionality of this code sequence, so it is recommended to inspect the machine code generated. The approximation interval for computation of 2^f is now centered around zero, since -0.5 <= f <= 0.5, requiring a different core approximation.

/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 t, f, p, r;
    __m128i i, j;

    const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
    const __m128 cvt = _mm_set1_ps (12582912.0f);  /* 1.5 * (1 << 23) */
    const __m128 c0 =  _mm_set1_ps (0.238428936f);
    const __m128 c1 =  _mm_set1_ps (0.703448006f);
    const __m128 c2 =  _mm_set1_ps (1.000443142f);

    /* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
    t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
    r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
    f = _mm_sub_ps (t, r);               /* f = t - rint (t) */
    i = _mm_cvtps_epi32 (t);             /* i = (int)rint (t) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
    j = _mm_slli_epi32 (i, 23);          /* i << 23 */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}
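The magic-constant rounding used for r in the variant above can be tried out on its own in scalar code. This is a sketch under a few assumptions: the name rint_magic is hypothetical, the rounding mode is the default round-to-nearest-even, float arithmetic is evaluated in single precision, and |t| does not exceed 2^22. The volatile qualifier is my own addition to discourage the compiler from folding the add/subtract pair away:

```c
/* Round t to the nearest integer (ties to even) via the 1.5 * 2^23 trick */
static float rint_magic (float t)
{
    const float cvt = 12582912.0f;  /* 1.5 * (1 << 23) */
    volatile float s = t + cvt;     /* sum has no fraction bits left, so it rounds here */
    return s - cvt;                 /* exact subtraction recovers rint (t) */
}
```

Near 1.5 * 2^23 the spacing between adjacent floats is exactly 1, so the addition itself performs the rounding; halfway cases go to the even neighbor, e.g. 2.5f becomes 2.0f and 3.5f becomes 4.0f.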

The algorithm for the code in the question appears to be taken from the work of Nicol N. Schraudolph, which cleverly exploits the semi-logarithmic nature of IEEE-754 binary floating-point formats:

N. N. Schraudolph. "A fast, compact approximation of the exponential function." Neural Computation, 11(4), May 1999, pp.853-862.

After removal of the argument clamping code, it reduces to just three SSE instructions. The "magical" correction constant 486411 is not optimal for minimizing maximum relative error over the entire input domain. Based on simple binary search, the value 298765 seems to be superior, reducing the maximum relative error of FastExpSse() to 3.56e-2, compared with a maximum relative error of 1.73e-3 for fast_exp_sse().

/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
    __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    __m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
    __m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
    return _mm_castsi128_ps (t);
}
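To see why those three instructions work, here is a scalar rendition of FastExpSse() for a single value. It is a sketch only (the name fast_exp_schraudolph_scalar is hypothetical, and like the version above it has no argument clamping). I use _mm_cvtss_si32 for the float-to-int conversion so that the rounding matches _mm_cvtps_epi32:

```c
#include <emmintrin.h>  /* SSE2 intrinsics */
#include <stdint.h>
#include <string.h>

static float fast_exp_schraudolph_scalar (float x)
{
    const float a = 12102203.0f;                 /* (1 << 23) / log(2) */
    const int32_t b = 127 * (1 << 23) - 298765;  /* exponent bias minus correction */
    /* a * x = 2^23 * (i + f), where i + f = x / log(2). Converted to an integer,
       i lands in the exponent field and f in the mantissa field, so without the
       correction the bit pattern would encode 2^i * (1 + f). */
    int32_t t = _mm_cvtss_si32 (_mm_set_ss (a * x)) + b;
    float r;
    memcpy (&r, &t, sizeof r);                   /* reinterpret the bits as a float */
    return r;
}
```

For x = 1 this gives about 2.814 against the true value 2.718, which is close to the worst case of roughly 3.56e-2 relative error quoted above.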

Schraudolph's algorithm basically uses the linear approximation 2^f ~= 1.0 + f for f in [0,1], and its accuracy could be improved by adding a quadratic term. The clever part of Schraudolph's approach is computing 2^i * 2^f without explicitly separating the integer portion i = floor(x * 1.44269504) from the fraction. I see no way to extend that trick to a quadratic approximation, but one can certainly combine the floor() computation from Schraudolph with the quadratic approximation used above:

/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
    __m128 f, p, r;
    __m128i t, j;
    const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
    const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
    const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
    const __m128 c0 = _mm_set1_ps (0.3371894346f);
    const __m128 c1 = _mm_set1_ps (0.657636276f);
    const __m128 c2 = _mm_set1_ps (1.00172476f);

    t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
    j = _mm_and_si128 (t, m);            /* j = (int)(floor (x/log(2))) << 23 */
    t = _mm_sub_epi32 (t, j);
    f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
    p = c0;                              /* c0 */
    p = _mm_mul_ps (p, f);               /* c0 * f */
    p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
    p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
    p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
    r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
    return r;
}