What is the difference between map and flatMap and a good use case for each?

It boils down to your original question: what do you mean by flattening?

When you use flatMap, a "multi-dimensional" collection becomes a "one-dimensional" collection.

val array1d = Array("1,2,3", "4,5,6", "7,8,9")
// array1d is an array of strings

val array2d = array1d.map(x => x.split(","))
// array2d will be: Array(Array(1, 2, 3), Array(4, 5, 6), Array(7, 8, 9))

val flatArray = array1d.flatMap(x => x.split(","))
// flatArray will be: Array(1, 2, 3, 4, 5, 6, 7, 8, 9)

You want to use flatMap when:

  • your map function results in creating multi-layered structures,
  • but all you want is a simple, flat, one-dimensional structure with the internal groupings removed (note that flatMap flattens exactly one level of nesting), as the sketch below shows.
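
To see the relationship concretely, here is a minimal plain-Scala sketch (no Spark needed; the value names are mine) of the identity that flatMap is map followed by a one-level flatten:

val lines = Array("1,2,3", "4,5,6", "7,8,9")

// map introduces one level of nesting; flatten removes it again...
val viaMapFlatten = lines.map(_.split(",")).flatten
// ...which is exactly what flatMap does in a single step.
val viaFlatMap = lines.flatMap(_.split(","))

viaMapFlatten.sameElements(viaFlatMap)  // true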

All the examples above are good. Here is a nice visual illustration (source courtesy: DataFlair's Spark training).

Map: map is a transformation operation in Apache Spark. It applies to each element of an RDD and returns the result as a new RDD. In the map operation, the developer can define their own custom business logic, and the same logic is applied to all the elements of the RDD.

Spark's RDD map function takes one element as input, processes it according to custom code (specified by the developer), and returns exactly one element at a time. map transforms an RDD of length N into another RDD of length N: the input and output RDDs always have the same number of records.

[Image: map transformation, where each input element produces exactly one output element]

Example of map using Scala:

val x = spark.sparkContext.parallelize(List("spark", "map", "example",  "sample", "example"), 3)
val y = x.map(x => (x, 1))
y.collect
// res0: Array[(String, Int)] = 
//    Array((spark,1), (map,1), (example,1), (sample,1), (example,1))

// RDD y can be rewritten with shorter syntax in Scala as
val y = x.map((_, 1))
y.collect
// res1: Array[(String, Int)] = 
//    Array((spark,1), (map,1), (example,1), (sample,1), (example,1))

// Another example: making a tuple of each string and its length
val y = x.map(x => (x, x.length))
y.collect
// res2: Array[(String, Int)] = 
//    Array((spark,5), (map,3), (example,7), (sample,6), (example,7))

flatMap:

flatMap is also a transformation operation. It applies to each element of an RDD and returns the result as a new RDD. It is similar to map, but flatMap allows returning 0, 1, or more elements from the function. In the flatMap operation, a developer can define their own custom business logic, and the same logic is applied to all the elements of the RDD.

What does "flatten the results" mean?

A flatMap function takes one element as input, processes it according to custom code (specified by the developer), and returns 0 or more elements at a time. flatMap() transforms an RDD of length N into another RDD of length M.

[Image: flatMap transformation, where each input element produces zero or more output elements, which are flattened into the resulting RDD]

Example of flatMap using Scala:

val x = spark.sparkContext.parallelize(List("spark flatmap example",  "sample example"), 2)

// a map operation returns an Array of Arrays in the following case: check the type of res0
val y = x.map(x => x.split(" ")) // split(" ") returns an array of words
y.collect
// res0: Array[Array[String]] = 
//  Array(Array(spark, flatmap, example), Array(sample, example))

// a flatMap operation returns an Array of words in the following case: check the type of res1
val y = x.flatMap(x => x.split(" "))
y.collect
//res1: Array[String] = 
//  Array(spark, flatmap, example, sample, example)

// RDD y can be rewritten with shorter syntax in Scala as
val y = x.flatMap(_.split(" "))
y.collect
//res2: Array[String] = 
//  Array(spark, flatmap, example, sample, example)

The word count example is commonly used with Hadoop. Let's take that same use case, apply both map and flatMap, and see the difference in how each processes the data.

Below is the sample data file.

hadoop is fast
hive is sql on hdfs
spark is superfast
spark is awesome

The above file will be parsed using map and flatMap.

Using map

>>> data = sc.textFile("data.txt")  # assuming the sample file above is saved as data.txt
>>> wc = data.map(lambda line: line.split(" "))
>>> wc.collect()
[[u'hadoop', u'is', u'fast'], [u'hive', u'is', u'sql', u'on', u'hdfs'], [u'spark', u'is', u'superfast'], [u'spark', u'is', u'awesome']]

Input has 4 lines and output size is 4 as well, i.e., N elements ==> N elements (each output element is itself a list of words).

Using flatMap

>>> fm = data.flatMap(lambda line: line.split(" "))
>>> fm.collect()
[u'hadoop', u'is', u'fast', u'hive', u'is', u'sql', u'on', u'hdfs', u'spark', u'is', u'superfast', u'spark', u'is', u'awesome']

Unlike map, the output is flattened into a single list of words: 4 lines in, 14 words out, i.e., N elements ==> M elements.


Let's assign 1 as the value for each word to get the word count.

  • fm: RDD created by using flatMap
  • wc: RDD created using map

>>> fm.map(lambda word: (word, 1)).collect()
[(u'hadoop', 1), (u'is', 1), (u'fast', 1), (u'hive', 1), (u'is', 1), (u'sql', 1), (u'on', 1), (u'hdfs', 1), (u'spark', 1), (u'is', 1), (u'superfast', 1), (u'spark', 1), (u'is', 1), (u'awesome', 1)]

Whereas flatMap on the RDD wc gives the undesired output below:

>>> wc.flatMap(lambda word: (word, 1)).collect()
[[u'hadoop', u'is', u'fast'], 1, [u'hive', u'is', u'sql', u'on', u'hdfs'], 1, [u'spark', u'is', u'superfast'], 1, [u'spark', u'is', u'awesome'], 1]

You can't get the word count if map is used instead of flatMap.
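
To finish the job, here is a minimal sketch of the complete word count in Scala (assuming, as above, that the sample file is saved at the hypothetical path data.txt):

val data = sc.textFile("data.txt")  // hypothetical path for the four-line sample file

val counts = data
  .flatMap(_.split(" "))    // flatten each line into individual words
  .map(word => (word, 1))   // pair every word with a count of 1
  .reduceByKey(_ + _)       // sum the counts for each distinct word

counts.collect()
// e.g. Array((is,4), (spark,2), (hadoop,1), (hive,1), ...)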

As per the definition, the difference between map and flatMap is:

map: returns a new RDD by applying the given function to each element of the RDD. The function in map returns exactly one item.

flatMap: similar to map, it returns a new RDD by applying a function to each element of the RDD, but the output is flattened.


Here is an example of the difference, as a spark-shell session:

First, some data - two lines of text:

val rdd = sc.parallelize(Seq("Roses are red", "Violets are blue"))  // lines

rdd.collect

    res0: Array[String] = Array("Roses are red", "Violets are blue")

Now, map transforms an RDD of length N into another RDD of length N.

For example, it maps from two lines into two line-lengths:

rdd.map(_.length).collect

    res1: Array[Int] = Array(13, 16)

But flatMap (loosely speaking) transforms an RDD of length N into a collection of N collections, then flattens these into a single RDD of results.

rdd.flatMap(_.split(" ")).collect

    res2: Array[String] = Array("Roses", "are", "red", "Violets", "are", "blue")

We have multiple words per line, and multiple lines, but we end up with a single output array of words.

Just to illustrate that, flatMapping from a collection of lines to a collection of words looks like:

["aa bb cc", "", "dd"] => [["aa","bb","cc"],[],["dd"]] => ["aa","bb","cc","dd"]

The input and output RDDs will therefore typically be of different sizes for flatMap.
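
The shrinking case is worth a quick sketch too: a flatMap whose function returns an empty collection for some inputs yields fewer output records than inputs (reusing the rdd defined above; the length cutoff is just for illustration):

// Keep only words longer than 4 characters: 6 words in, 2 words out.
rdd.flatMap(_.split(" ").filter(_.length > 4)).collect
// Array(Roses, Violets)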

If we had tried to use map with our split function, we'd have ended up with nested structures (an RDD of arrays of words, with type RDD[Array[String]]) because we have to have exactly one result per input:

rdd.map(_.split(" ")).collect

    res3: Array[Array[String]] = Array(
                                     Array(Roses, are, red), 
                                     Array(Violets, are, blue)
                                 )

Finally, one useful special case is mapping with a function which might not return an answer, and so returns an Option. We can use flatMap to filter out the elements that return None and extract the values from those that return a Some:

val rdd = sc.parallelize(Seq(1,2,3,4))

def myfn(x: Int): Option[Int] = if (x <= 2) Some(x * 10) else None

rdd.flatMap(myfn).collect

    res4: Array[Int] = Array(10,20)

(noting here that an Option behaves rather like a list that has either one element, or zero elements)
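
To make that connection explicit, here is a sketch of the more verbose pipeline that rdd.flatMap(myfn) replaces, keeping only the defined results and then unwrapping them:

// Equivalent to rdd.flatMap(myfn): drop the Nones, then extract each Some.
rdd.map(myfn).filter(_.isDefined).map(_.get).collect
// Array(10, 20)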
