Java 8u40 Math.round() very slow

Casual benchmarking: you benchmark A, but actually measure B, and conclude you've measured C.

Modern JVMs are very complex and perform all kinds of optimizations. If you try to measure a small piece of code, it is hard to do so correctly without very detailed knowledge of what the JVM is doing. A common culprit in broken benchmarks is dead-code elimination: compilers are smart enough to deduce that some computations are redundant and eliminate them completely. Please read these slides: http://shipilev.net/talks/jvmls-July2014-benchmarking.pdf. To "fix" Adam's microbenchmark (I still can't tell what it is measuring, and this "fix" does not account for warm-up, OSR, and many other microbenchmarking pitfalls), we have to print the result of the computation to system output:

    int result = 0;
    long t0 = System.currentTimeMillis();
    for (int i = 0; i < 1e9; i++) {
        result += Math.round((float) i / (float) (i + 1));
    }
    long t1 = System.currentTimeMillis();
    System.out.println("result = " + result);
    System.out.println(String.format("%s, Math.round(float), %.1f ms", System.getProperty("java.version"), (t1 - t0)/1f));

As a result:

result = 999999999
1.8.0_25, Math.round(float), 5251.0 ms

result = 999999999
1.8.0_40, Math.round(float), 3903.0 ms

The same "fix" applied to the original MVCE example gives:

It took 401772 milliseconds to complete edu.jvm.runtime.RoundFloatToInt. <==== 1.8.0_40

It took 410767 milliseconds to complete edu.jvm.runtime.RoundFloatToInt. <==== 1.8.0_25
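To make the dead-code-elimination point concrete, here is a minimal sketch that contrasts a loop whose result is thrown away with one whose result is returned and printed. The class and method names (DceDemo, discarded, consumed) are made up for illustration, and whether the JIT actually removes the first loop depends on the JVM version and flags, so treat any timings it prints as anecdotal rather than as a measurement.

```java
class DceDemo {
    // The rounded value is never used, so the JIT is free to drop the work.
    static void discarded(int n) {
        for (int i = 0; i < n; i++) {
            Math.round((float) i / (float) (i + 1));
        }
    }

    // The sum escapes through the return value, so the rounding
    // cannot simply be thrown away as dead code.
    static int consumed(int n) {
        int result = 0;
        for (int i = 0; i < n; i++) {
            result += Math.round((float) i / (float) (i + 1));
        }
        return result;
    }

    public static void main(String[] args) {
        int n = 100_000_000; // illustrative size, smaller than the 1e9 used above

        long t0 = System.nanoTime();
        discarded(n);
        long t1 = System.nanoTime();
        int result = consumed(n);
        long t2 = System.nanoTime();

        System.out.println("result = " + result);
        System.out.printf("discarded: %.1f ms, consumed: %.1f ms%n",
                (t1 - t0) / 1e6, (t2 - t1) / 1e6);
    }
}
```

If the first number comes out suspiciously close to zero, the loop was most likely optimized away and nothing about Math.round was measured at all.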

If you want to measure the real cost of Math#round, you should write something like this (based on JMH):

package org.openjdk.jmh.samples;

import org.openjdk.jmh.annotations.*;
import org.openjdk.jmh.runner.Runner;
import org.openjdk.jmh.runner.RunnerException;
import org.openjdk.jmh.runner.options.Options;
import org.openjdk.jmh.runner.options.OptionsBuilder;
import org.openjdk.jmh.runner.options.VerboseMode;

import java.util.Random;
import java.util.concurrent.TimeUnit;

@State(Scope.Benchmark)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.NANOSECONDS)
@Warmup(iterations = 3, time = 5, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 3, time = 5, timeUnit = TimeUnit.SECONDS)
public class RoundBench {

    float[] floats;
    int i;

    @Setup
    public void initI() {
        Random random = new Random(0xDEAD_BEEF);
        floats = new float[8096];
        for (int i = 0; i < floats.length; i++) {
            floats[i] = random.nextFloat();
        }
    }

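    @Benchmark
    public float baseline() {
        // Note: the mask clears the low 8 bits, so after i++ the index is
        // reset to 0 and every call reads floats[0].
        i++;
        i = i & 0xFFFFFF00;
        return floats[i];
    }

    @Benchmark
    public int round() {
        // Same indexing as baseline(), plus the Math.round call under test.
        i++;
        i = i & 0xFFFFFF00;
        return Math.round(floats[i]);
    }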

    public static void main(String[] args) throws RunnerException {
        Options options = new OptionsBuilder()
                .include(RoundBench.class.getName())
                .build();
        new Runner(options).run();
    }
}

My results are:

1.8.0_25
Benchmark            Mode  Cnt  Score   Error  Units
RoundBench.baseline  avgt    6  2.565 ± 0.028  ns/op
RoundBench.round     avgt    6  4.459 ± 0.065  ns/op

1.8.0_40 
Benchmark            Mode  Cnt  Score   Error  Units
RoundBench.baseline  avgt    6  2.589 ± 0.045  ns/op
RoundBench.round     avgt    6  4.588 ± 0.182  ns/op
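Subtracting the baseline score from the round score gives a rough estimate of what Math#round itself costs in each version. The sketch below just does that arithmetic on the scores reported above. Adding the two error margins is a deliberately crude, conservative bound assumed here for illustration, not something JMH reports, and the class name NetRoundCost is hypothetical.

```java
import java.util.Locale;

class NetRoundCost {
    // Net cost of the call under test: benchmark score minus baseline score.
    static double net(double roundScore, double baselineScore) {
        return roundScore - baselineScore;
    }

    // Crude combined margin: simply add both reported errors (an assumption).
    static double margin(double roundError, double baselineError) {
        return roundError + baselineError;
    }

    static void report(String version, double round, double roundErr,
                       double baseline, double baselineErr) {
        System.out.println(String.format(Locale.ROOT,
                "%s: net Math.round cost = %.3f +/- %.3f ns/op",
                version, net(round, baseline), margin(roundErr, baselineErr)));
    }

    public static void main(String[] args) {
        // Scores copied from the JMH tables above.
        report("1.8.0_25", 4.459, 0.065, 2.565, 0.028);
        report("1.8.0_40", 4.588, 0.182, 2.589, 0.045);
    }
}
```

That works out to roughly 1.894 ± 0.093 ns/op on 1.8.0_25 and 1.999 ± 0.227 ns/op on 1.8.0_40. The two ranges overlap, so these numbers do not show a meaningful per-call slowdown.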

To find the root cause of the problem, you can use JITWatch (https://github.com/AdoptOpenJDK/jitwatch/). To save time, I can say that the size of the JIT-compiled code for Math#round increased in 1.8.0_40. This is almost unnoticeable for small methods, but in huge methods the much longer stretch of machine code pollutes the instruction cache.


MVCE based on OP

  • can likely be simplified further
  • changed the int3 = statements to int3 += to reduce the chance of dead code removal. With int3 =, 8u40 is a factor of 3 slower than 8u31; with int3 +=, it is only 15% slower.
  • printed the result to further reduce the chance of dead code removal optimisations

Code

public class MathTime {
    static float[][] float1 = new float[8][16];
    static float[][] float2 = new float[8][16];

    public static void main(String[] args) {
        for (int j = 0; j < 8; j++) {
            for (int k = 0; k < 16; k++) {
                float1[j][k] = (float) (j + k);
                float2[j][k] = (float) (j + k);
            }
        }
        new Test().run();
    }

    private static class Test {
        int int3;

        public void run() {
            for (String test : new String[] { "warmup", "real" }) {

                long t0 = System.nanoTime();

                for (int count = 0; count < 1e7; count++) {
                    int i = count % 8;
                    int3 += Math.round(float1[i][0] + float2[i][0]);
                    int3 += Math.round(float1[i][1] + float2[i][1]);
                    int3 += Math.round(float1[i][2] + float2[i][2]);
                    int3 += Math.round(float1[i][3] + float2[i][3]);
                    int3 += Math.round(float1[i][4] + float2[i][4]);
                    int3 += Math.round(float1[i][5] + float2[i][5]);
                    int3 += Math.round(float1[i][6] + float2[i][6]);
                    int3 += Math.round(float1[i][7] + float2[i][7]);
                    int3 += Math.round(float1[i][8] + float2[i][8]);
                    int3 += Math.round(float1[i][9] + float2[i][9]);
                    int3 += Math.round(float1[i][10] + float2[i][10]);
                    int3 += Math.round(float1[i][11] + float2[i][11]);
                    int3 += Math.round(float1[i][12] + float2[i][12]);
                    int3 += Math.round(float1[i][13] + float2[i][13]);
                    int3 += Math.round(float1[i][14] + float2[i][14]);
                    int3 += Math.round(float1[i][15] + float2[i][15]);

                    int3 += Math.round(float1[i][0] * float2[i][0]);
                    int3 += Math.round(float1[i][1] * float2[i][1]);
                    int3 += Math.round(float1[i][2] * float2[i][2]);
                    int3 += Math.round(float1[i][3] * float2[i][3]);
                    int3 += Math.round(float1[i][4] * float2[i][4]);
                    int3 += Math.round(float1[i][5] * float2[i][5]);
                    int3 += Math.round(float1[i][6] * float2[i][6]);
                    int3 += Math.round(float1[i][7] * float2[i][7]);
                    int3 += Math.round(float1[i][8] * float2[i][8]);
                    int3 += Math.round(float1[i][9] * float2[i][9]);
                    int3 += Math.round(float1[i][10] * float2[i][10]);
                    int3 += Math.round(float1[i][11] * float2[i][11]);
                    int3 += Math.round(float1[i][12] * float2[i][12]);
                    int3 += Math.round(float1[i][13] * float2[i][13]);
                    int3 += Math.round(float1[i][14] * float2[i][14]);
                    int3 += Math.round(float1[i][15] * float2[i][15]);

                    int3 += Math.round(float1[i][0] / float2[i][0]);
                    int3 += Math.round(float1[i][1] / float2[i][1]);
                    int3 += Math.round(float1[i][2] / float2[i][2]);
                    int3 += Math.round(float1[i][3] / float2[i][3]);
                    int3 += Math.round(float1[i][4] / float2[i][4]);
                    int3 += Math.round(float1[i][5] / float2[i][5]);
                    int3 += Math.round(float1[i][6] / float2[i][6]);
                    int3 += Math.round(float1[i][7] / float2[i][7]);
                    int3 += Math.round(float1[i][8] / float2[i][8]);
                    int3 += Math.round(float1[i][9] / float2[i][9]);
                    int3 += Math.round(float1[i][10] / float2[i][10]);
                    int3 += Math.round(float1[i][11] / float2[i][11]);
                    int3 += Math.round(float1[i][12] / float2[i][12]);
                    int3 += Math.round(float1[i][13] / float2[i][13]);
                    int3 += Math.round(float1[i][14] / float2[i][14]);
                    int3 += Math.round(float1[i][15] / float2[i][15]);

                }
                long t1 = System.nanoTime();
                System.out.println(int3);
                System.out.println(String.format("%s, Math.round(float), %s, %.1f ms", System.getProperty("java.version"), test, (t1 - t0) / 1e6));
            }
        }
    }
}

Results

adam@brimstone:~$ ./jdk1.8.0_40/bin/javac MathTime.java;./jdk1.8.0_40/bin/java -cp . MathTime 
1.8.0_40, Math.round(float), warmup, 6846.4 ms
1.8.0_40, Math.round(float), real, 6058.6 ms
adam@brimstone:~$ ./jdk1.8.0_31/bin/javac MathTime.java;./jdk1.8.0_31/bin/java -cp . MathTime 
1.8.0_31, Math.round(float), warmup, 5717.9 ms
1.8.0_31, Math.round(float), real, 5282.7 ms
adam@brimstone:~$ ./jdk1.8.0_25/bin/javac MathTime.java;./jdk1.8.0_25/bin/java -cp . MathTime 
1.8.0_25, Math.round(float), warmup, 5702.4 ms
1.8.0_25, Math.round(float), real, 5262.2 ms

Observations

  • For trivial uses of Math.round(float) I can find no difference in performance on my platform (Linux x86_64). The difference shows up only in benchmarks: my previous naive and incorrect benchmarks merely exposed differences in optimisation behaviour, as Ivan's answer and Marco13's comments point out.
  • 8u40 is less aggressive in dead code elimination than previous versions, so more code is executed in some corner cases and those cases run slower.
  • 8u40 takes slightly longer to warm up, but is quicker once it gets "there".
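One way to look at the warm-up behaviour mentioned above is to time the same workload over several rounds and print each round separately. This is only a sketch under assumptions: the class name WarmupRounds, the round count and the data are invented for illustration, and a hand-rolled loop like this is still no substitute for a proper harness such as JMH. The hot loop lives in its own method so that it can be compiled as a normal method after the first calls instead of relying on on-stack replacement alone.

```java
class WarmupRounds {
    // Hot loop in its own method; the sum is returned so it stays live.
    static int workload(float[] data, int iterations) {
        int sum = 0;
        for (int n = 0; n < iterations; n++) {
            sum += Math.round(data[n % data.length]);
        }
        return sum;
    }

    public static void main(String[] args) {
        float[] data = new float[16];
        for (int k = 0; k < data.length; k++) {
            data[k] = k + 0.5f;
        }

        int checksum = 0;
        for (int round = 0; round < 5; round++) {
            long t0 = System.nanoTime();
            checksum += workload(data, 10_000_000);
            long t1 = System.nanoTime();
            System.out.printf("round %d: %.1f ms%n", round, (t1 - t0) / 1e6);
        }
        // Printing the checksum keeps the work from being treated as dead code.
        System.out.println("checksum = " + checksum);
    }
}
```

The first round or two typically include interpretation and compilation time, so only the later rounds say anything about steady-state speed.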

Source analysis

Surprisingly, Math.round(float) is a pure Java implementation rather than native, and the code for 8u31 and 8u40 is identical.

diff  jdk1.8.0_31/src/java/lang/Math.java jdk1.8.0_40/src/java/lang/Math.java
-no differences-

public static int round(float a) {
    int intBits = Float.floatToRawIntBits(a);
    int biasedExp = (intBits & FloatConsts.EXP_BIT_MASK)
            >> (FloatConsts.SIGNIFICAND_WIDTH - 1);
    int shift = (FloatConsts.SIGNIFICAND_WIDTH - 2
            + FloatConsts.EXP_BIAS) - biasedExp;
    if ((shift & -32) == 0) { // shift >= 0 && shift < 32
        // a is a finite number such that pow(2,-32) <= ulp(a) < 1
        int r = ((intBits & FloatConsts.SIGNIF_BIT_MASK)
                | (FloatConsts.SIGNIF_BIT_MASK + 1));
        if (intBits < 0) {
            r = -r;
        }
        // In the comments below each Java expression evaluates to the value
        // the corresponding mathematical expression:
        // (r) evaluates to a / ulp(a)
        // (r >> shift) evaluates to floor(a * 2)
        // ((r >> shift) + 1) evaluates to floor((a + 1/2) * 2)
        // (((r >> shift) + 1) >> 1) evaluates to floor(a + 1/2)
        return ((r >> shift) + 1) >> 1;
    } else {
        // a is either
        // - a finite number with abs(a) < exp(2,FloatConsts.SIGNIFICAND_WIDTH-32) < 1/2
        // - a finite number with ulp(a) >= 1 and hence a is a mathematical integer
        // - an infinity or NaN
        return (int) a;
    }
}
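To experiment with this algorithm outside the JDK, the following sketch copies the logic into a standalone class and compares it with Math.round on a few edge cases. FloatConsts is an internal class, so the constants here are written out as literals based on the IEEE 754 single-precision layout (24-bit significand including the implicit bit, exponent bias 127); the class name RoundSketch and the sample values are illustrative assumptions.

```java
class RoundSketch {
    // Literal stand-ins for the FloatConsts values (IEEE 754 binary32 layout).
    static final int SIGNIFICAND_WIDTH = 24;
    static final int EXP_BIAS = 127;
    static final int EXP_BIT_MASK = 0x7F800000;
    static final int SIGNIF_BIT_MASK = 0x007FFFFF;

    static int round(float a) {
        int intBits = Float.floatToRawIntBits(a);
        int biasedExp = (intBits & EXP_BIT_MASK) >> (SIGNIFICAND_WIDTH - 1);
        int shift = (SIGNIFICAND_WIDTH - 2 + EXP_BIAS) - biasedExp;
        if ((shift & -32) == 0) { // shift >= 0 && shift < 32
            // Significand with the implicit leading 1 restored, then signed.
            int r = (intBits & SIGNIF_BIT_MASK) | (SIGNIF_BIT_MASK + 1);
            if (intBits < 0) {
                r = -r;
            }
            // floor(a * 2), plus 1, halved: floor(a + 1/2).
            return ((r >> shift) + 1) >> 1;
        }
        // Tiny values, already-integral values, infinities and NaN.
        return (int) a;
    }

    public static void main(String[] args) {
        float[] samples = {0.0f, 0.5f, -0.5f, 2.5f, -2.5f, 0.49999997f,
                1e10f, Float.NaN, Float.NEGATIVE_INFINITY};
        for (float x : samples) {
            System.out.println(x + " -> sketch=" + round(x)
                    + ", Math.round=" + Math.round(x));
        }
    }
}
```

As a worked case, 2.5f has a biased exponent of 128, so shift is 149 - 128 = 21 and r is 0xA00000. Then r >> 21 is 5 (that is, floor(2.5 * 2)), and (5 + 1) >> 1 gives 3.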
