Find prime numbers using Scala. Help me to improve

Here's a functional implementation of the Sieve of Eratosthenes, as presented in Odersky's "Functional Programming Principles in Scala" Coursera course:

  // Sieving integral numbers
  def sieve(s: Stream[Int]): Stream[Int] = {
    s.head #:: sieve(s.tail.filter(_ % s.head != 0))
  }

  // All primes as a lazy sequence
  val primes = sieve(Stream.from(2))

  // Dumping the first five primes
  print(primes.take(5).toList) // List(2, 3, 5, 7, 11)

The style looks fine to me. Although the Sieve of Eratosthenes is a very efficient way to find prime numbers, your approach works well too, since you are only testing for division against known primes. You need to watch out, however: your recursive function is not tail recursive. A tail-recursive function does not modify the result of the recursive call, whereas in your example you prepend to the result of the recursive call. This means that the call stack grows with every step, so findPrime can overflow the stack for large i. Here is a tail-recursive solution.

def primesUnder(n: Int): List[Int] = {
  require(n >= 2)

  @annotation.tailrec // the compiler rejects rec if it is not tail recursive
  def rec(i: Int, primes: List[Int]): List[Int] = {
    if (i >= n) primes
    else if (prime(i, primes)) rec(i + 1, i :: primes)
    else rec(i + 1, primes)
  }

  rec(2, List()).reverse
}

def prime(num: Int, factors: List[Int]): Boolean = factors.forall(num % _ != 0)

This solution isn't prettier; it's more of a detail to get your solution to work for large arguments. Since the list is built up backwards to take advantage of fast prepends, the list needs to be reversed at the end. As an alternative, you could use an Array, Vector or a ListBuffer to append the results. With the Array, however, you would need to estimate how much memory to allocate for it. Fortunately we know that pi(n), the number of primes up to n, is approximately n / ln(n), so you can choose a reasonable size with some headroom. Array and ListBuffer are also mutable data types, which goes against your desire for functional style.
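For the Array option, a rough sketch of the sizing idea might look like the following. It assumes the upper bound pi(n) < 1.25506 n / ln(n) (valid for n > 1) instead of the plain approximation, so the array is never too small. The names primeCountBound and primesUnderArray are made up for this example.

```scala
// Upper bound on the number of primes below n (assumed bound: 1.25506 * n / ln n).
def primeCountBound(n: Int): Int =
  if (n < 3) 1
  else math.ceil(1.25506 * n / math.log(n)).toInt

def primesUnderArray(n: Int): Array[Int] = {
  require(n >= 2)
  val buf = new Array[Int](primeCountBound(n)) // preallocated, possibly too large
  var count = 0                                // number of slots actually filled
  for (i <- 2 until n) {
    // test i only against stored primes p with p * p <= i
    val isPrime = buf.iterator.take(count)
      .takeWhile(p => p.toLong * p <= i)
      .forall(i % _ != 0)
    if (isPrime) { buf(count) = i; count += 1 }
  }
  buf.take(count) // drop the unused tail of the array
}
```

The final take(count) copies the filled prefix, so callers never see the unused slots.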

Update: to get good performance out of the Sieve of Eratosthenes I think you'll need to store data in a native array, which also goes against your desire for style in functional programming. There might be a creative functional implementation though!
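For comparison, a minimal array-based Sieve of Eratosthenes could look like the sketch below. It is deliberately imperative, the names sieveUnder and composite are made up for this example, and it is only meant for modest n.

```scala
def sieveUnder(n: Int): List[Int] = {
  require(n >= 2)
  // composite(i) becomes true once i is known to have a smaller prime factor
  val composite = new Array[Boolean](n)
  var p = 2
  while (p.toLong * p < n) {
    if (!composite(p)) {
      // cross off multiples of p, starting at p * p (smaller ones are already marked)
      var m = p * p
      while (m < n) { composite(m) = true; m += p }
    }
    p += 1
  }
  (2 until n).filterNot(i => composite(i)).toList
}
```

Unlike trial division, this never performs a modulo operation; the work is all in marking array slots.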

Update: oops! Missed it! This approach works well too if you only divide by primes up to the square root of the number you are testing. I missed this, and unfortunately it's not easy to adjust my solution to do this because I'm storing the primes backwards.

Update: here's a very non-functional solution that at least only checks up to the square root.

import scala.collection.mutable.ListBuffer

def primesUnder(n: Int): List[Int] = {
  require(n >= 2)

  val primes = ListBuffer.empty[Int]
  for (i <- 2 until n) { // "under n": n itself is excluded, as in the recursive version
    if (prime(i, primes.iterator)) {
      primes += i
    }
  }

  primes.toList
}

// factors must be in ascending order
def prime(num: Int, factors: Iterator[Int]): Boolean = {
  val limit = math.sqrt(num).toInt // computed once rather than once per factor
  factors.takeWhile(_ <= limit).forall(num % _ != 0)
}

Or I could use Vectors with my original approach. Vectors are probably not the best solution because appending to a Vector is not the fastest constant-time operation, even though it is amortized O(1).
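A sketch of what the Vector variant might look like is below. It keeps the tail-recursive shape of the first solution but stores the primes in ascending order, so the square-root cutoff becomes easy. primesUnderVector and isPrime are names made up for this example.

```scala
import scala.annotation.tailrec

def primesUnderVector(n: Int): Vector[Int] = {
  require(n >= 2)

  // primes are in ascending order, so takeWhile can stop at the square root
  def isPrime(i: Int, primes: Vector[Int]): Boolean =
    primes.iterator.takeWhile(p => p.toLong * p <= i).forall(i % _ != 0)

  @tailrec
  def rec(i: Int, primes: Vector[Int]): Vector[Int] =
    if (i >= n) primes
    else if (isPrime(i, primes)) rec(i + 1, primes :+ i) // append keeps the order
    else rec(i + 1, primes)

  rec(2, Vector.empty)
}
```

No reverse is needed at the end, and nothing is mutated.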


As schmmd mentions, you want it to be tail recursive, and you also want it to be lazy. Fortunately there is a perfect data-structure for this: Stream.

This is a very efficient prime calculator implemented as a Stream, with a few optimisations:

object Prime {
  def is(i: Long): Boolean =
    if (i < 2) false // 0, 1 and negative numbers are not prime
    else if (i == 2) true
    else if ((i & 1) == 0) false // efficient div by 2
    else prime(i)

  def primes: Stream[Long] = 2 #:: prime3

  private val prime3: Stream[Long] = {
    @annotation.tailrec
    def nextPrime(i: Long): Long =
      if (prime(i)) i else nextPrime(i + 2) // tail

    def next(i: Long): Stream[Long] =
      i #:: next(nextPrime(i + 2))

    3 #:: next(5)

  }

  // assumes i is odd; check evenness before calling
  // perf note: pass the partially applied >= method so that math.sqrt(i) is evaluated only once
  def prime(i: Long): Boolean =
    prime3 takeWhile (math.sqrt(i).>= _) forall { i % _ != 0 }
}

Prime.is is the prime check predicate, and Prime.primes returns a Stream of all prime numbers. prime3 is where the Stream is computed, using the prime predicate to check for prime divisors up to and including the square root of i.
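On Scala 2.13 and later, Stream is deprecated in favour of LazyList. A possible port of the same idea is sketched below; it is not the code above but a variant under that assumption, and the names LazyPrimes, oddPrimes and isOddPrime are made up for this example.

```scala
object LazyPrimes {
  // Odd primes from 3 upward; each candidate is tested only against earlier entries.
  private lazy val oddPrimes: LazyList[Long] =
    3L #:: LazyList.iterate(5L)(_ + 2).filter(isOddPrime)

  // Assumes i is odd and at least 3; divides by primes p with p * p <= i.
  private def isOddPrime(i: Long): Boolean =
    oddPrimes.takeWhile(p => p * p <= i).forall(i % _ != 0)

  val primes: LazyList[Long] = 2L #:: oddPrimes

  def is(i: Long): Boolean =
    if (i < 2) false
    else if (i == 2) true
    else if ((i & 1) == 0) false
    else isOddPrime(i)
}

// Usage: first ten primes, and a couple of membership checks.
val firstTen = LazyPrimes.primes.take(10).toList
val check97 = LazyPrimes.is(97)
val check91 = LazyPrimes.is(91) // 91 = 7 * 13
```

Comparing p * p with i avoids the floating-point square root altogether, at the cost of one multiplication per tested divisor.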