Explanation of the aggregate Scala function

From the documentation:

def aggregate[B](z: ⇒ B)(seqop: (B, A) ⇒ B, combop: (B, B) ⇒ B): B

Aggregates the results of applying an operator to subsequent elements.

This is a more general form of fold and reduce. It has similar semantics, but does not require the result to be a supertype of the element type. It traverses the elements in different partitions sequentially, using seqop to update the result, and then applies combop to results from different partitions. The implementation of this operation may operate on an arbitrary number of collection partitions, so combop may be invoked an arbitrary number of times.

For example, one might want to process some elements and then produce a Set. In this case, seqop would process an element and add it to the set, while combop would take the union of two sets from different partitions. The initial value z would be an empty set.

pc.aggregate(Set[Int]())((set, elem) => set + process(elem), _ ++ _)

Another example is calculating the geometric mean of a collection of doubles (one would typically require big doubles for this).

  - B: the type of accumulated results
  - z: the initial value for the accumulated result of the partition; this will typically be the neutral element for the seqop operator (e.g. Nil for list concatenation or 0 for summation) and may be evaluated more than once
  - seqop: an operator used to accumulate results within a partition
  - combop: an associative operator used to combine results from different partitions
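To make the geometric-mean example concrete, here is a minimal sketch. Two assumptions that are not part of the documentation: instead of big doubles it sums logarithms (the accumulator is a (sum of logs, count) pair), and it simulates partitions with grouped(...) so that it runs without the parallel collections module. The names GeometricMeanSketch and partitionSize are hypothetical.

```scala
// Minimal sketch (assumption): geometric mean computed in log space, with
// partitions simulated by grouped(...) instead of a real parallel collection.
object GeometricMeanSketch {
  // Accumulator B = (sum of ln(x), number of elements seen)
  type Acc = (Double, Int)

  // z: neutral element for both operators below
  val z: Acc = (0.0, 0)

  // seqop: fold one Double element into the accumulator
  val seqop: (Acc, Double) => Acc =
    (acc, x) => (acc._1 + math.log(x), acc._2 + 1)

  // combop: merge two partial accumulators (associative)
  val combop: (Acc, Acc) => Acc =
    (a, b) => (a._1 + b._1, a._2 + b._2)

  def geometricMean(xs: Seq[Double], partitionSize: Int): Double = {
    require(xs.nonEmpty && xs.forall(_ > 0), "needs non-empty, positive input")
    require(partitionSize > 0, "partitionSize must be positive")
    val (logSum, n) = xs
      .grouped(partitionSize)
      .map(part => part.foldLeft(z)(seqop)) // seqop inside each partition
      .foldLeft(z)(combop)                  // combop across partitions
    math.exp(logSum / n)
  }
}
```

For instance, GeometricMeanSketch.geometricMean(Seq(2.0, 8.0), 1) is 4.0 up to floating-point rounding. Working in log space avoids the overflow that multiplying many doubles directly would cause, which is one common alternative to the big doubles the documentation mentions.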

In your example, B is Tuple2[Int, Int]. The function seqop takes the current accumulator x (of type B) and a single element y from the list, and returns the updated accumulator (x._1 + y, x._2 + 1): it adds the element to the first component of the tuple and increments the second. This effectively puts the sum of the elements into the first component of the tuple and the number of elements into the second.

The function combop then takes the partial results produced for the different partitions (possibly by different threads) and combines them. Because component-wise addition is associative, the combined result is the same as if the whole list had been processed sequentially.
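A sketch of these two steps for the (sum, count) tuple, assuming the list in the question is List(1, 2, 3, 4, 5, 6) (which matches the 21 and 6 shown below). Partitions are simulated by hand with grouped(2); that split is a hypothetical choice, since a real parallel collection decides its own split.

```scala
object SumCountSketch {
  // Assumed input from the question
  val xs = List(1, 2, 3, 4, 5, 6)

  // seqop: (accumulator, element) => accumulator
  val seqop: ((Int, Int), Int) => (Int, Int) =
    (x, y) => (x._1 + y, x._2 + 1)

  // combop: (partial result, partial result) => partial result
  val combop: ((Int, Int), (Int, Int)) => (Int, Int) =
    (x, y) => (x._1 + y._1, x._2 + y._2)

  // Each simulated partition is folded with seqop, starting from (0, 0)
  val partials: List[(Int, Int)] =
    xs.grouped(2).map(_.foldLeft((0, 0))(seqop)).toList

  // The partial results are then merged with combop
  val combined: (Int, Int) = partials.reduce(combop)

  // The same computation done sequentially, for comparison
  val sequential: (Int, Int) = xs.foldLeft((0, 0))(seqop)
}
```

Here partials is List((3, 2), (7, 2), (11, 2)), and both combined and sequential come out as (21, 6).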

Using a tuple for B is probably the confusing part. You can break the problem down into two subproblems to get a better idea of what is going on: res0 below corresponds to the first element of the result tuple, and res1 to the second.

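// On Scala 2.13 and later, .par requires the scala-parallel-collections module and
// import scala.collection.parallel.CollectionConverters._

// Sums all elements in parallel.
scala> x.par.aggregate(0)((acc, elem) => acc + elem, (a, b) => a + b)
res0: Int = 21

// Counts all elements in parallel.
scala> x.par.aggregate(0)((count, _) => count + 1, (a, b) => a + b)
res1: Int = 6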
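The count case is where combop matters most. The sketch below assumes, as the 21 and 6 above suggest, that x is List(1, 2, 3, 4, 5, 6), and simulates partitions with grouped(2) (a hypothetical split). Each partition's seqop produces a local count, and combop must add those counts; a combop that merely increments, shown here as a deliberate mistake, undercounts.

```scala
object CountSketch {
  val xs = List(1, 2, 3, 4, 5, 6) // assumed input

  // seqop ignores the element's value and just increments the count
  val seqop: (Int, Int) => Int = (count, _) => count + 1

  // Local counts per simulated partition: List(2, 2, 2)
  val localCounts: List[Int] = xs.grouped(2).map(_.foldLeft(0)(seqop)).toList

  // Correct combop: add the partial counts
  val count: Int = localCounts.reduce(_ + _)

  // Wrong combop (hypothetical mistake): reuse the "+1" logic across partitions
  val wrongCount: Int = localCounts.reduce((a, _) => a + 1)
}
```

With this split, count is 6 while wrongCount is 4, which is why seqop and combop are usually different functions even when B and the element type are both Int.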

First of all, thanks to Diego's reply, which helped me connect the dots in understanding the aggregate() function.

Let me confess that I couldn't sleep properly last night because I couldn't work out how aggregate() works internally; I'll definitely get a good night's sleep tonight :-)

Let's start understanding it.

val result = List(1,2,3,4,5,6,7,8,9,10).par.aggregate((0, 0))(
  (x, y) => (x._1 + y, x._2 + 1),
  (x, y) => (x._1 + y._1, x._2 + y._2)
)

result: (Int, Int) = (55,10)

The aggregate function has 3 parts:

  1. initial value of the accumulator: the tuple (0,0) here
  2. seqop: it works like foldLeft, starting from the initial value (0,0)
  3. combop: it combines the results generated by the parallel computations (this part was difficult for me to understand)

Let's look at each of the 3 parts independently:

Part 1: initial tuple (0,0)

aggregate() starts with the initial value of the accumulator x, which is (0,0) here. The first element of the tuple, x._1, which is initially 0, is used to compute the sum; the second element, x._2, is used to count the elements in the list.

Part 2: (x, y) => (x._1 + y, x._2 + 1)

If you know how foldLeft works in Scala, this part should be easy to understand. The function above is applied just like the operator of a foldLeft; if the whole List(1,2,3,...,10) were processed as a single partition, the steps would be:

Iteration#      (x._1 + y, x._2 + 1)
     1           (0+1, 0+1)
     2           (1+2, 1+1)
     3           (3+3, 2+1)
     4           (6+4, 3+1)
     .             ....
     .             ....
     10          (45+10, 9+1)
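The table above can be reproduced with scanLeft, which returns every intermediate accumulator of a foldLeft. This is a sequential sketch of seqop alone (no partitions); the object name SeqOpTrace is chosen only for illustration.

```scala
object SeqOpTrace {
  val seqop: ((Int, Int), Int) => (Int, Int) =
    (x, y) => (x._1 + y, x._2 + 1)

  // steps(0) is the initial value (0, 0); steps(i) is the accumulator after iteration i
  val steps: List[(Int, Int)] = (1 to 10).toList.scanLeft((0, 0))(seqop)

  def printTable(): Unit =
    steps.zipWithIndex.tail.foreach { case ((sum, count), i) =>
      println(s"Iteration $i: ($sum, $count)")
    }
}
```

Calling SeqOpTrace.printTable() prints one line per iteration, from (1, 1) up to (55, 10).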

Thus after all 10 iterations you'll get the result (55,10). If you understand this part, the rest is very easy. For me, though, this was the most difficult part to understand: if all the required computation is already finished, what is the use of the third part, i.e. combop? Stay tuned :-)

Part 3: (x, y) => (x._1 + y._1, x._2 + y._2)

This third part is combop, which combines the results generated by different threads during parallel execution. Remember that we used par in our code to enable parallel computation over the list:

List(1,2,3,4,5,6,7,8,9,10).par.aggregate(....)

Apache Spark offers the same pattern for parallel computation on an RDD: RDD.aggregate applies seqOp within each partition and combOp across partitions.

Let's assume that our List(1,2,3,4,5,6,7,8,9,10) is processed by 3 threads in parallel. Each thread works on part of the list, and then aggregate() combines the results of the threads' computations using combop:

(x, y) => (x._1 + y._1, x._2 + y._2)

Original list: List(1,2,3,4,5,6,7,8,9,10)

Say Thread-1 computes the partial list (1,2,3,4), Thread-2 computes (5,6,7,8) and Thread-3 computes (9,10).

Each thread applies seqop to its own partial list, starting from the initial value (0,0). At the end of the computation, Thread-1's result will be (10,4), Thread-2's result will be (26,4) and Thread-3's result will be (19,2).

At the end of the parallel computation we'll have ((10,4),(26,4),(19,2)). Combining these partial results with combop (starting, for illustration, from (0,0)) gives:

Iteration#      (x._1 + y._1, x._2 + y._2)
     1           (0+10, 0+4)
     2           (10+26, 4+4)
     3           (36+19, 8+2)

which is (55,10).
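The three-thread scenario can be replayed deterministically by writing the partitions out by hand. The split below is the hypothetical one from the explanation; a real parallel collection picks its own split and merge order, which is why combop has to be associative.

```scala
object ThreePartitions {
  val seqop: ((Int, Int), Int) => (Int, Int) =
    (x, y) => (x._1 + y, x._2 + 1)
  val combop: ((Int, Int), (Int, Int)) => (Int, Int) =
    (x, y) => (x._1 + y._1, x._2 + y._2)

  // The hypothetical split used in the explanation
  val partitions: List[List[Int]] =
    List(List(1, 2, 3, 4), List(5, 6, 7, 8), List(9, 10))

  // Each "thread" folds its own partition with seqop, starting from (0, 0)
  val partials: List[(Int, Int)] = partitions.map(_.foldLeft((0, 0))(seqop))

  // combop merges the partial results left to right, starting from (0, 0)
  val total: (Int, Int) = partials.foldLeft((0, 0))(combop)

  // Merging with a different grouping gives the same answer
  val regrouped: (Int, Int) = combop(partials(0), combop(partials(1), partials(2)))
}
```

partials evaluates to List((10, 4), (26, 4), (19, 2)), and both total and regrouped are (55, 10).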

Finally, let me reiterate that seqop's job is to compute, within each partition, the sum of the elements and the number of elements, whereas combop's job is to combine the different partial results generated during parallel execution.

I hope the explanation above helps you understand aggregate().


aggregate takes 3 parameters: a seed value, a computation function and a combination function.

Basically, it splits the collection into a number of partitions that can be processed by different threads, computes partial results using the computation function, and then combines all these partial results using the combination function.
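That description can be written down as a small sequential model of aggregate. This is a sketch for understanding only, not the library's implementation: the name aggregateModel and the fixed partitionSize are assumptions, and the partitions are processed one after another rather than on separate threads.

```scala
object AggregateModel {
  // z is by-name, as in the real signature, so it is re-evaluated for every partition
  def aggregateModel[A, B](xs: Seq[A], partitionSize: Int)(z: => B)(
      seqop: (B, A) => B,
      combop: (B, B) => B
  ): B = {
    require(partitionSize > 0, "partitionSize must be positive")
    xs.grouped(partitionSize)
      .map(part => part.foldLeft(z)(seqop)) // partial result per partition
      .foldLeft(z)(combop)                  // combine the partial results
  }
}
```

For example, AggregateModel.aggregateModel(List(1, 2, 3, 4, 5, 6), 2)((0, 0))((x, y) => (x._1 + y, x._2 + 1), (x, y) => (x._1 + y._1, x._2 + y._2)) returns (21, 6), and so does any other positive partition size.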

From what I can tell, your example function returns a pair (a, b) where a is the sum of the values in the list and b is the number of values in the list. Indeed, the result is (21, 6).

How does this work? The seed value is the (0,0) pair. For an empty list, we have a sum of 0 and a count of 0, so this is correct.
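Why (0,0) is the right seed can be checked directly: it is the result for an empty list, and combining it with any partial result changes nothing, so it does not matter that aggregate may evaluate the seed once per partition. A small sketch; the object name SeedCheck and the sample partial results are illustrative.

```scala
object SeedCheck {
  val seed: (Int, Int) = (0, 0)
  val seqop: ((Int, Int), Int) => (Int, Int) = (x, y) => (x._1 + y, x._2 + 1)
  val combop: ((Int, Int), (Int, Int)) => (Int, Int) =
    (x, y) => (x._1 + y._1, x._2 + y._2)

  // Empty input: seqop is never called, so the seed itself is the answer
  val emptyResult: (Int, Int) = List.empty[Int].foldLeft(seed)(seqop)

  // The seed is neutral for combop: merging it into a partial result is a no-op
  def isNeutralFor(partial: (Int, Int)): Boolean =
    combop(seed, partial) == partial && combop(partial, seed) == partial
}
```

With a seed such as (1, 0) instead, every partition would start one too high, so the final sum would depend on how many partitions the collection happened to be split into.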

Your computation function takes an (Int, Int) pair x, which is your partial result, and an Int y, which is the next value in the list. This is your computation function:

(x, y) => (x._1 + y, x._2 + 1)

Indeed, the result that we want is to increase the left element of x (the accumulator) by y, and the right element of x (the counter) by 1 for each y.

Your combination function takes an (Int, Int) pair x and an (Int, Int) pair y, which are your two partial results from different parallel computations, and combines them together as:

(x,y) => (x._1 + y._1, x._2 + y._2)

Indeed, we independently sum the left parts of the pairs and the right parts of the pairs.

Your confusion comes from the fact that x and y in the first function ARE NOT the same x and y of the second function. In the first function, x has the type of the seed value, y has the type of the collection elements, and you return a result of the type of x. In the second function, both parameters have the same type as your seed value, and so does the result.
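Writing the parameter types out explicitly, and giving the parameters different names, makes the difference visible. This sketch uses hypothetical names (acc, elem, left, right) in place of x and y; the behaviour is the same as the original lambdas.

```scala
object TypedOps {
  // seqop: (B, A) => B with B = (Int, Int) and A = Int
  val seqop: ((Int, Int), Int) => (Int, Int) =
    (acc: (Int, Int), elem: Int) => (acc._1 + elem, acc._2 + 1)

  // combop: (B, B) => B, so both parameters have the seed's type
  val combop: ((Int, Int), (Int, Int)) => (Int, Int) =
    (left: (Int, Int), right: (Int, Int)) => (left._1 + right._1, left._2 + right._2)
}
```

With the types spelled out, passing an Int where combop expects a pair (or a pair where seqop expects an element) is rejected at compile time.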

Hope it's clearer now!