Java 8: streams and the Sieve of Eratosthenes

Sure, it is possible, but it is greatly complicated by the fact that Java streams have no simple way of being decomposed into their head and their tail (you can easily get either one of these, but not both, since the stream will already have been consumed by then; sounds like someone could use linear types...).
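To make the head/tail problem concrete, here is a small sketch (the class name `HeadTail` is only for illustration). Taking the head with `findFirst()` is a terminal operation, so any later attempt to build the tail from the same stream object throws an `IllegalStateException`:

```java
import java.util.OptionalInt;
import java.util.stream.IntStream;

public class HeadTail {
    public static void main(String[] args) {
        IntStream nums = IntStream.iterate(2, i -> i + 1);

        // Taking the head is a terminal operation, so it consumes the stream.
        OptionalInt head = nums.findFirst();
        System.out.println("head = " + head.getAsInt());

        try {
            // Building the tail needs the same stream object, which is now used up.
            IntStream tail = nums.skip(1);
            System.out.println("tail first = " + tail.findFirst().getAsInt());
        } catch (IllegalStateException e) {
            System.out.println("no tail: " + e.getMessage());
        }
    }
}
```

Asking for the tail first fails the same way when you then try to get the head.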

The solution is to keep a mutable variable around. For instance, that mutable variable can be the predicate that tests whether a number is a multiple of any other number seen so far.

import java.util.stream.*;
import java.util.function.IntPredicate;

public class Primes {

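   // Accumulated test "not a multiple of any prime seen so far"; it is
   // reassigned every time a new prime passes through peek().
   static IntPredicate isPrime = x -> true;
   static IntStream primes = IntStream
                               .iterate(2, i -> i + 1)
                               // A lambda rather than isPrime::test, so the *current*
                               // predicate is looked up for every candidate.
                               .filter(i -> isPrime.test(i))
                               .peek(i -> isPrime = isPrime.and(v -> v % i != 0));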

   public static void main(String[] args) {
      // Print out the first 10 primes.
      primes.limit(10)
            .forEach(p -> System.out.println(p));

   }
}

Then, you get the expected result:

$ javac Primes.java
$ java Primes
2
3
5
7
11
13
17
19
23
29
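Because `isPrime` is a static field, the `primes` stream above shares one predicate for the whole program and can be consumed only once. A minimal sketch of the same mutable-predicate idea with per-call state, assuming you want a fresh stream each time (`FreshPrimes` and its `primes()` method are hypothetical names), keeps the predicate in a one-element array, since a lambda cannot reassign a local variable:

```java
import java.util.List;
import java.util.function.IntPredicate;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class FreshPrimes {
    // Hypothetical helper: every call gets its own predicate "cell",
    // so separate streams do not share state.
    static IntStream primes() {
        // Lambdas cannot reassign a local variable, but they can
        // replace the contents of an effectively final array.
        final IntPredicate[] isPrime = { x -> true };
        return IntStream.iterate(2, i -> i + 1)
                        .filter(i -> isPrime[0].test(i))
                        .peek(i -> isPrime[0] = isPrime[0].and(v -> v % i != 0));
    }

    public static void main(String[] args) {
        // Two independent streams: the second starts again from 2.
        List<Integer> first = primes().limit(10).boxed().collect(Collectors.toList());
        List<Integer> second = primes().limit(5).boxed().collect(Collectors.toList());
        System.out.println(first);   // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
        System.out.println(second);  // [2, 3, 5, 7, 11]
    }
}
```

Like the original, this is trial division by every prime found so far rather than a true sieve, and the predicate chain grows by one `and` per prime, so very long prefixes can overflow the call stack.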

If you'd accept a Scala solution instead, here it is:

def sieve(nums:Stream[Int]):Stream[Int] = nums.head #:: sieve(nums.filter{_ % nums.head > 0})
val primes:Stream[Int] = sieve(Stream.from(2))

It is not as elegant as the Haskell solution but it comes pretty close IMO. Here is the output:

scala> primes take 10 foreach println
2
3
5
7
11
13
17
19
23
29

Scala's Stream is a lazy list, which is far lazier than the Java 8 Stream. In the documentation you can even find an example Fibonacci sequence implementation that corresponds to the canonical Haskell zipWith implementation.
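A Java 8 stream cannot be defined in terms of itself, so the zipWith-style definition has no direct counterpart. A sketch of the usual workaround (the class `Fib` and method `fibs()` are hypothetical names) carries the last two values explicitly as a `long[]` pair through `Stream.iterate`:

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class Fib {
    // Each state is {f(n), f(n+1)}; the state is carried explicitly
    // because a Java stream cannot refer to its own tail.
    static Stream<Long> fibs() {
        return Stream.iterate(new long[] {0, 1}, p -> new long[] {p[1], p[0] + p[1]})
                     .map(p -> p[0]);
    }

    public static void main(String[] args) {
        List<Long> first = fibs().limit(10).collect(Collectors.toList());
        System.out.println(first); // [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
    }
}
```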


EDIT: The sieve, unoptimised, returning an infinite stream of primes

import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;
import java.util.stream.Stream;

public static Stream<Integer> primeStreamEra() {
    // Maps each prime found so far to its current multiple
    final Map<Integer, Integer> seedsFactors = new HashMap<>();
    // Start at 2: 1 is not prime (and a seed of 1 would mark every number)
    return IntStream.iterate(2, i -> i + 1)
                    .filter(i -> {
                        seedsFactors.entrySet().parallelStream()
                            .forEach(e -> {
                                // Advance each prime's multiple until it is
                                // the closest value that is >= i
                                while (e.getValue() < i)
                                    e.setValue(e.getValue() + e.getKey());
                            });
                        // i is prime iff no prime's multiple landed exactly on it
                        if (!seedsFactors.containsValue(i)) {
                            seedsFactors.put(i, i);
                            return true;
                        }
                        return false;
                    }).boxed();
}

Test:

public static void main(String[] args) {
    // The stream is infinite, so cap it before printing
    primeStreamEra().limit(10).forEach(System.out::println);
}
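The version above advances every stored multiple on every candidate, and `containsValue` scans the whole map. A sketch of the usual optimisation, often called an incremental sieve, assumes a `HashMap` keyed by the *next composite* instead of by prime. This is a different bookkeeping from the code above, and `IncrementalSieve` and `primes()` are hypothetical names. It only touches the primes that actually cross out the current candidate:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

public class IncrementalSieve {
    // Hypothetical helper: infinite stream of primes.
    static IntStream primes() {
        // next composite -> primes whose current multiple is that composite
        final Map<Long, List<Integer>> composites = new HashMap<>();
        return IntStream.iterate(2, n -> n + 1).filter(n -> {
            List<Integer> factors = composites.remove((long) n);
            if (factors == null) {
                // n is prime: its first multiple worth crossing out is n*n
                composites.computeIfAbsent((long) n * n, k -> new ArrayList<>()).add(n);
                return true;
            }
            // n is composite: move each of its primes on to the next multiple
            for (int p : factors) {
                composites.computeIfAbsent((long) n + p, k -> new ArrayList<>()).add(p);
            }
            return false;
        });
    }

    public static void main(String[] args) {
        System.out.println(primes().limit(10).boxed().collect(Collectors.toList()));
        // [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    }
}
```

Starting each prime p at p*p is safe because any smaller multiple of p has a smaller prime factor that has already crossed it out. Keys are `long` so that `n * n` does not overflow `int`.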

Initial Post:

A somewhat simpler solution that avoids some unnecessary operations (such as testing even numbers).

We iterate over all odd numbers from 3 up to and including the limit.

Within the filter function:

  • We test against all primes found so far that are less than or equal to sqrt(currentNumber), rounded down.
  • If any of them divides the current number, return false.
  • Otherwise, add the number to the list of found primes and return true.

Function:

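import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.IntStream;

public static IntStream primeStream(final int limit) {
    // Odd primes found so far, in increasing order (the filter below is
    // stateful, so the stream must be consumed sequentially)
    final List<Integer> primes = new ArrayList<>();
    // Odd numbers 3, 5, 7, ... <= limit; limit() instead of takeWhile(),
    // which only exists since Java 9
    IntStream primesThreeToLimit =
           IntStream.iterate(3, i -> i + 2)
                    .limit(Math.max(0, (limit - 1) / 2))
                    .filter(i -> {
                        final int testUntil = (int) Math.sqrt(i);
                        for (Integer p : primes) {
                            if (p > testUntil) break;      // no factor beyond sqrt(i)
                            if (i % p == 0) return false;  // composite
                        }
                        primes.add(i);
                        return true;
                    });
    // 2 is the only even prime; 1 is not prime
    return IntStream.concat(IntStream.of(2).filter(i -> i <= limit), primesThreeToLimit);
}

Test:

public static void main(String[] args) {
    System.out.println(Arrays.toString(primeStream(50).toArray()));
}

Output: [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]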

Edit: To convert from IntStream to Stream<Integer> just do primeStream(50).boxed().