Apply a custom function to a Spark DataFrame group

What you are looking for has existed since Spark 2.3: pandas UDFs (a.k.a. vectorized UDFs), and in particular their grouped map variant. They let you group a DataFrame and apply a custom pandas transformation to each group, distributed over the groups:

df.groupBy("groupColumn").apply(myCustomPandasTransformation)

It is very easy to use, so I will just put a link to Databricks' presentation of pandas UDFs.
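For illustration, here is a minimal sketch of a grouped map pandas UDF; the sample data, the output schema string and the group-centering logic are all made up:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import pandas_udf, PandasUDFType

    spark = SparkSession.builder.getOrCreate()

    # Made-up data: one value column per group.
    df = spark.createDataFrame([("a", 1.0), ("a", 2.0), ("b", 3.0)],
                               ["groupColumn", "value"])

    # The schema string declares the columns the transformation returns.
    @pandas_udf("groupColumn string, value double", PandasUDFType.GROUPED_MAP)
    def myCustomPandasTransformation(pdf):
        # pdf is a plain pandas.DataFrame holding one complete group;
        # as an example, center the values within the group.
        return pdf.assign(value=pdf.value - pdf.value.mean())

    df.groupBy("groupColumn").apply(myCustomPandasTransformation).show()

In Spark 3.x the same pattern is written df.groupBy("groupColumn").applyInPandas(myCustomPandasTransformation, schema) instead.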

However, I am not yet aware of an equally practical way to do grouped transformations in Scala, so any additional advice is welcome.

EDIT: in Scala, you have been able to achieve the same thing since earlier versions of Spark, using the Dataset API's groupByKey followed by mapGroups/flatMapGroups.


  • While Spark provides some ways to integrate with pandas, it doesn't make pandas itself distributed. So whatever you do with pandas in Spark is a purely local operation (either on the driver, or on an executor when used inside transformations).

    If you're looking for a distributed system with a pandas-like API, you should take a look at dask.

  • You can define User Defined Aggregate Functions (UDAFs) or Aggregators to process grouped Datasets, but this part of the API is directly accessible only in Scala. Once you have created one, it is not that hard to write a Python wrapper (see the first sketch after this list).
  • The RDD API provides a number of functions which can be used to perform operations on groups, ranging from the low-level repartition / repartitionAndSortWithinPartitions to the *byKey methods (combineByKey, groupByKey, reduceByKey, etc.).

    Which one is applicable in your case depends on the properties of the function you want to apply (is it associative and commutative, can it work on a stream of values, does it expect a specific order?).

    The most general but inefficient approach can be summarized as follows:

    h(rdd.keyBy(f).groupByKey().mapValues(g).collect())
    

    where f maps a value to its key, g performs the per-group aggregation, and h is a final merge on the driver. Most of the time you can do much better than that, so this pattern should be used only as a last resort (a worked example follows this list).

  • Relatively complex grouped logic can be expressed using DataFrames / Spark SQL and window functions (see the last example after this list).

  • See also Applying UDFs on GroupedData in PySpark (with a functioning Python example).
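To illustrate the Python wrapper mentioned in the UDAF point: once a Scala Aggregator/UDAF has been compiled and shipped with the job, it can be registered and called from PySpark as a SQL function. A minimal sketch, assuming a hypothetical Scala class com.example.MyAggregator on the classpath and made-up column names:

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()

    # Hypothetical Scala UDAF compiled into a jar that is on the classpath.
    spark.udf.registerJavaUDAF("myAgg", "com.example.MyAggregator")

    df = spark.createDataFrame([("a", 1.0), ("a", 2.0), ("b", 3.0)],
                               ["groupColumn", "value"])
    df.groupBy("groupColumn").agg(F.expr("myAgg(value)")).show()

Note that registerJavaUDAF was only added in Spark 2.3.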
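As a concrete (and deliberately naive) instance of the keyBy / groupByKey / mapValues pattern from the RDD point, with made-up choices for f, g and h:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()
    sc = spark.sparkContext

    rdd = sc.parallelize([1, 2, 3, 4, 5, 6])

    f = lambda x: x % 2     # maps a value to its key (group by parity)
    g = sum                 # per-group aggregation
    h = dict                # final merge on the driver

    result = h(rdd.keyBy(f).groupByKey().mapValues(g).collect())
    # result == {0: 12, 1: 9}

Since this particular g is associative and commutative, rdd.map(lambda x: (f(x), x)).reduceByKey(operator.add) would avoid materializing whole groups, which is exactly the kind of improvement the answer alludes to.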
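And a small sketch of the window-function approach (the data and the group-centering logic are again made up):

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.getOrCreate()

    df = spark.createDataFrame([("a", 1.0), ("a", 2.0), ("b", 3.0)],
                               ["groupColumn", "value"])

    # Subtract each group's mean from its rows without collapsing them.
    w = Window.partitionBy("groupColumn")
    df.withColumn("centered", F.col("value") - F.avg("value").over(w)).show()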