How to return a "Tuple type" in a UDF in PySpark?

Stack Overflow keeps directing me to this question, so I'll add some more information here.

Returning simple types from a UDF:

from pyspark.sql.types import *
from pyspark.sql import functions as F

def get_df():
    """Return a fresh 4-row demo DataFrame with columns ``x`` and ``y``.

    The rows are built from Python floats, so both columns are doubles.
    Relies on a ``sqlContext`` being available in the surrounding scope.
    """
    rows = [
        (0.0, 0.0),
        (0.0, 3.0),
        (1.0, 6.0),
        (1.0, 9.0),
    ]
    columns = ['x', 'y']
    return sqlContext.createDataFrame(rows, columns)

df = get_df()
df.show()

# +---+---+
# |  x|  y|
# +---+---+
# |0.0|0.0|
# |0.0|3.0|
# |1.0|6.0|
# |1.0|9.0|
# +---+---+

# `udf` itself is not imported above (the wildcard import of
# pyspark.sql.types does not provide it), so use the `F` alias of
# pyspark.sql.functions. The second argument declares the Spark return type;
# plain Python callables such as `str`/`int` can be wrapped directly.
func = F.udf(str, StringType())
df = df.withColumn('y_str', func('y'))

func = F.udf(int, IntegerType())
df = df.withColumn('y_int', func('y'))

df.show()

# +---+---+-----+-----+
# |  x|  y|y_str|y_int|
# +---+---+-----+-----+
# |0.0|0.0|  0.0|    0|
# |0.0|3.0|  3.0|    3|
# |1.0|6.0|  6.0|    6|
# |1.0|9.0|  9.0|    9|
# +---+---+-----+-----+

df.printSchema()

# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- y_str: string (nullable = true)
#  |-- y_int: integer (nullable = true)

When integers are not enough:

df = get_df()

# Arrays: a returned Python list needs an ArrayType with its element type.
func = F.udf(lambda x: [0] * int(x), ArrayType(IntegerType()))
df = df.withColumn('list', func('y'))

# Maps: a returned Python dict needs a MapType with key and value types.
# Keys here are 0.0 .. int(y) - 1, each mapped to its integer as a string.
func = F.udf(lambda x: {float(i): str(i) for i in range(int(x))},
             MapType(FloatType(), StringType()))
df = df.withColumn('map', func('y'))

df.show()
# +---+---+--------------------+--------------------+
# |  x|  y|                list|                 map|
# +---+---+--------------------+--------------------+
# |0.0|0.0|                  []|               Map()|
# |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
# |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
# +---+---+--------------------+--------------------+

df.printSchema()
# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- list: array (nullable = true)
#  |    |-- element: integer (containsNull = true)
#  |-- map: map (nullable = true)
#  |    |-- key: float
#  |    |-- value: string (valueContainsNull = true)

Returning complex data types from a UDF:

# Gather every `y` of each `x` group into one array column named `y[]`.
# NOTE: collect_list does not guarantee element order (it depends on row order
# after the shuffle), so the arrays may come out in a different order per run.
df = (
    get_df()
    .groupBy('x')
    .agg(F.collect_list('y').alias('y[]'))
)
df.show()

# +---+----------+
# |  x|       y[]|
# +---+----------+
# |0.0|[0.0, 3.0]|
# |1.0|[9.0, 6.0]|
# +---+----------+

# Spark has no tuple type: a Python tuple returned by a UDF is matched to a
# StructType *positionally*, so the UDF must return exactly one element per
# field below, in this order.
schema = StructType([
    StructField("min", FloatType(), True),
    StructField("size", IntegerType(), True),
    StructField("edges", ArrayType(FloatType()), True),
    StructField("val_to_index", MapType(FloatType(), IntegerType()), True),
    # Fields may themselves be structs, e.g.
    #   StructField("nested", StructType([StructField("min_", FloatType(), True),
    #                                     StructField("size_", IntegerType(), True)]), True),
    # but then the UDF must also return a matching nested tuple in that slot.
])

def func(values):
    """Summarise one group's list of values as a 4-tuple matching ``schema``.

    Returns ``(min, size, edges, val_to_index)`` where ``edges`` holds the
    values sorted in descending order and ``val_to_index`` maps each value to
    its position in ``values`` (for duplicates, the last position wins).

    A null array (``None``) yields ``None``, i.e. a null struct. An empty list
    yields ``(None, 0, [], {})`` instead of letting ``min()`` raise and fail
    the Spark task.
    """
    if values is None:
        return None
    if not values:
        return (None, 0, [], {})
    return (
        min(values),
        len(values),
        sorted(values, reverse=True),
        {x: i for i, x in enumerate(values)},
    )

# Wrap the Python function as a UDF returning the struct `schema`; each
# returned tuple becomes one struct value. (`udf` is not imported directly,
# so use the `F` alias of pyspark.sql.functions.)
func = F.udf(func, schema)
dff = df.select('*', func('y[]').alias('complex_type'))
dff.show(10, False)

# +---+----------+------------------------------------------------------+
# |x  |y[]       |complex_type                                          |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+

dff.printSchema()

# root
#  |-- x: double (nullable = true)
#  |-- y[]: array (...)
#  |    |-- element: double (...)
#  |-- complex_type: struct (nullable = true)
#  |    |-- min: float (nullable = true)
#  |    |-- size: integer (nullable = true)
#  |    |-- edges: array (nullable = true)
#  |    |    |-- element: float (containsNull = true)
#  |    |-- val_to_index: map (nullable = true)
#  |    |    |-- key: float
#  |    |    |-- value: integer (valueContainsNull = true)
#
# (the nullability flags printed for `y[]` come from collect_list and may
# differ between Spark versions)

Passing multiple arguments to a UDF:

df = get_df()

# A UDF accepts several columns directly: each column becomes one positional
# argument of the Python function. (Packing them with F.array('x', 'y') and
# indexing arr[0] * arr[1] also works, but forces one common element type.)
# DoubleType matches the double inputs; FloatType would truncate to 32 bits.
func = F.udf(lambda x, y: x * y, DoubleType())
df = df.withColumn('x*y', func('x', 'y'))
df.show()

# +---+---+---+
# |  x|  y|x*y|
# +---+---+---+
# |0.0|0.0|0.0|
# |0.0|3.0|0.0|
# |1.0|6.0|6.0|
# |1.0|9.0|9.0|
# +---+---+---+

The code is purely for demonstration purposes: all of the transformations above are available as built-in Spark functions, which would perform much better. As @zero323 pointed out in the comment above, UDFs should generally be avoided in PySpark; needing to return complex types should make you think about simplifying your logic.


For the Scala version instead of Python (Spark 2.4):

import org.apache.spark.sql.types._

// Top-level schema: an integer, an array of (cnt_rnk, comp) structs, and a
// string. Every field uses StructField's default nullable = true.
val testschema: StructType = StructType(Seq(
  StructField("number", IntegerType),
  StructField("Array", ArrayType(StructType(Seq(
    StructField("cnt_rnk", IntegerType),
    StructField("comp", StringType)
  )))),
  StructField("comp", StringType)
))

The tree structure looks like this.

testschema.printTreeString()  // prints the schema as an indented tree
root
 |-- number: integer (nullable = true)
 |-- Array: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- cnt_rnk: integer (nullable = true)
 |    |    |-- comp: string (nullable = true)
 |-- comp: string (nullable = true)

There is no such thing as a TupleType in Spark. Product types are represented as structs with fields of specific types. For example, if you want to return an array of (string, integer) pairs, you can use a schema like this:

from pyspark.sql.types import *

schema = ArrayType(StructType([
    StructField("char", StringType(), False),
    StructField("count", IntegerType(), False)
]))

Example usage:

from pyspark.sql.functions import udf
from collections import Counter

# Counter(s).most_common() yields (char, count) tuples, most frequent first;
# each tuple fills one struct of `schema` positionally.
char_count_udf = udf(lambda s: Counter(s).most_common(), schema)

df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"])

df.select("*", char_count_udf(df["value"])).show(n=2, truncate=False)

# NOTE(review): on Python 3.7+ most_common() lists equal counts in
# first-encountered order, so the second row would presumably read
# [[b,1], [a,1], [r,1]]; the output below looks like it came from an older
# interpreter.
## +---+-----+-------------------------+
## |id |value|PythonUDF#<lambda>(value)|
## +---+-----+-------------------------+
## |1  |foo  |[[o,2], [f,1]]           |
## |2  |bar  |[[r,1], [a,1], [b,1]]    |
## +---+-----+-------------------------+