Reduce a key-value pair into a key-list pair with Apache Spark

tl;dr If you really require an operation like this, use groupByKey as suggested by @MariusIon. Every other solution proposed here is either plainly inefficient or at least suboptimal compared to direct grouping.

reduceByKey with list concatenation is not an acceptable solution because:

  • Requires initialization of O(N) lists.
  • Each application of + to a pair of lists requires a full copy of both lists (O(N)), effectively increasing the overall complexity to O(N^2).
  • Doesn't address any of the problems introduced by groupByKey. The amount of data that has to be shuffled, as well as the size of the final structure, is the same.
  • Contrary to what one of the answers suggests, there is no difference in the level of parallelism between implementations using reduceByKey and groupByKey.
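To make the complexity argument concrete, here is a plain-Python sketch (no Spark involved) that counts how many element copies are made when N single-element lists for one key are folded together with +, compared with appending each value to a single list. The function names are hypothetical, and the counts model a sequential left fold within one partition.

```python
def copies_with_concat(n):
    """Count element copies when folding n single-element lists with `+`."""
    acc = [0]  # the first value, already wrapped in a list
    copies = 0
    for i in range(1, n):
        # `acc + [i]` builds a brand-new list, copying every element of both operands
        copies += len(acc) + 1
        acc = acc + [i]
    return copies


def copies_with_append(n):
    """Count element writes when appending n values to one list."""
    acc = []
    writes = 0
    for i in range(n):
        acc.append(i)  # amortized O(1): only the new element is written
        writes += 1
    return writes


for n in (10, 100, 1000):
    print(n, copies_with_concat(n), copies_with_append(n))
# 10 54 10
# 100 5049 100
# 1000 500499 1000
```

The concatenation count is 2 + 3 + ... + N = N(N+1)/2 - 1, which is the quadratic growth described above, while appending stays linear.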

combineByKey with list.extend is a suboptimal solution because:

  • Creates O(N) list objects in mergeValue (this could be optimized by calling list.append directly with the new item).
  • If optimized with list.append, it is exactly equivalent to the old (Spark <= 1.3) implementation of groupByKey and ignores all the optimizations introduced by SPARK-3074, which enables external (on-disk) grouping of larger-than-memory structures.
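The append-based combiners discussed above can be sketched in plain Python. The helper local_combine_by_key is a hypothetical, purely in-memory stand-in for RDD.combineByKey: it only mirrors the three-function contract (createCombiner, mergeValue, mergeCombiners) and has no shuffle and no spilling to disk, which is exactly the limitation the SPARK-3074 point refers to.

```python
def create_combiner(value):
    # First value seen for a key in a partition: start a new list
    return [value]


def merge_value(acc, value):
    # Append in place instead of `acc.extend([value])`, which would build
    # a throwaway one-element list for every incoming value
    acc.append(value)
    return acc


def merge_combiners(left, right):
    # Merge two per-partition lists for the same key
    left.extend(right)
    return left


def local_combine_by_key(partitions, create, merge_val, merge_comb):
    """Hypothetical in-memory stand-in for RDD.combineByKey.

    `partitions` is a list of partitions, each a list of (key, value) pairs.
    """
    # Map side: combine values within each partition
    per_partition = []
    for part in partitions:
        combined = {}
        for key, value in part:
            if key in combined:
                combined[key] = merge_val(combined[key], value)
            else:
                combined[key] = create(value)
        per_partition.append(combined)
    # Reduce side: merge the per-partition combiners key by key
    result = {}
    for combined in per_partition:
        for key, acc in combined.items():
            if key in result:
                result[key] = merge_comb(result[key], acc)
            else:
                result[key] = acc
    return result


partitions = [[(1, "a"), (2, "b"), (1, "c")], [(1, "d"), (2, "e")]]
print(local_combine_by_key(partitions, create_combiner, merge_value, merge_combiners))
# {1: ['a', 'c', 'd'], 2: ['b', 'e']}
```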

Map and ReduceByKey

The input and output types of reduce must be the same, so if you want to aggregate values into a list, you first have to map each input value to a list. Afterwards you combine the lists into one list.
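A quick plain-Python illustration of why the mapping step matters, using functools.reduce in place of the per-key reduce and made-up values for a single key. Without wrapping, + on strings still works type-wise; it just silently concatenates the strings instead of collecting them.

```python
from functools import reduce
import operator

values = ["alice", "bob", "carol"]  # hypothetical values for a single key

# Reducing the raw strings with + keeps the type (str) but concatenates them
print(reduce(operator.add, values))   # alicebobcarol

# Wrapping each value in a one-element list first makes the reduce build a list
wrapped = [[v] for v in values]
print(reduce(operator.add, wrapped))  # ['alice', 'bob', 'carol']
```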

Combining lists

You'll need a method to combine lists into one list. Python provides some methods to combine lists.

append modifies the list in place and always returns None.

x = [1, 2, 3]
x.append([4, 5])
# x is [1, 2, 3, [4, 5]]

extend does the same, but unwraps lists:

x = [1, 2, 3]
x.extend([4, 5])
# x is [1, 2, 3, 4, 5]

Both methods return None, but you need an operation that returns the combined list, so just use the plus operator.

x = [1, 2, 3] + [4, 5]
# x is [1, 2, 3, 4, 5]

Spark

# sc is an existing SparkContext
lines = sc.textFile("hdfs://...")
actors = (lines.flatMap(lambda line: line.split(" "))
               .map(lambda actor: (actor.split(",")[0], actor))
               # transform each value into a single-element list
               .map(lambda name_tuple: (name_tuple[0], [name_tuple[1]]))
               # combine lists: [1, 2, 3] + [4, 5] becomes [1, 2, 3, 4, 5]
               .reduceByKey(lambda a, b: a + b))
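To check what the pipeline above produces without a cluster, the same steps can be replayed on a small Python list. The helper local_reduce_by_key and the sample lines are hypothetical; the helper folds values in input order, whereas Spark gives no ordering guarantee for the values within each key's list.

```python
def local_reduce_by_key(pairs, func):
    """Hypothetical stand-in for RDD.reduceByKey on a plain list of pairs."""
    result = {}
    for key, value in pairs:
        # Fold each new value into the running result for its key
        result[key] = func(result[key], value) if key in result else value
    return result


lines = ["alice,1 bob,2", "alice,3"]  # hypothetical input lines

tokens = [tok for line in lines for tok in line.split(" ")]  # flatMap
pairs = [(tok.split(",")[0], tok) for tok in tokens]         # map to (name, token)
lists = [(name, [tok]) for name, tok in pairs]               # wrap each value in a list
print(local_reduce_by_key(lists, lambda a, b: a + b))
# {'alice': ['alice,1', 'alice,3'], 'bob': ['bob,2']}
```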

CombineByKey

It's also possible to solve this with combineByKey, which is used internally to implement reduceByKey, but it's more complex, and "using one of the specialized per-key combiners in Spark can be much faster". Your use case is simple enough for the solution above.

GroupByKey

It's also possible to solve this with groupByKey, but it reduces parallelization and therefore could be much slower for big data sets.
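For comparison, here is a plain-Python sketch of what groupByKey followed by mapValues(list) computes: the same key-to-list mapping, built by appending each value once instead of concatenating lists. In PySpark the corresponding call would be rdd.groupByKey().mapValues(list), since groupByKey yields an iterable per key rather than a list. The helper local_group_by_key is hypothetical and models only the result, not Spark's shuffle or parallelism.

```python
from collections import defaultdict


def local_group_by_key(pairs):
    """Hypothetical stand-in for rdd.groupByKey().mapValues(list)."""
    groups = defaultdict(list)
    for key, value in pairs:
        groups[key].append(value)  # one amortized O(1) append per value, no list copies
    return dict(groups)


pairs = [("alice", "alice,1"), ("bob", "bob,2"), ("alice", "alice,3")]
print(local_group_by_key(pairs))
# {'alice': ['alice,1', 'alice,3'], 'bob': ['bob,2']}
```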


I'm kind of late to the conversation, but here's my suggestion:

>>> foo = sc.parallelize([(1, ('a','b')), (2, ('c','d')), (1, ('x','y'))])
>>> foo.map(lambda kv: (kv[0], [kv[1]])).reduceByKey(lambda p, q: p + q).collect()
[(1, [('a', 'b'), ('x', 'y')]), (2, [('c', 'd')])]