Rename nested field in spark dataframe

I found a much easier way than the one provided by @zero323, along the lines of @MaxPY:

PySpark 2.4:

from pyspark.sql.types import StructField

# Get the schema from the dataframe df
schema = df.schema

# Override `fields` with a list of new StructFields, identical to the previous ones except for the names
schema.fields = [StructField(field.name + "_renamed", field.dataType, field.nullable, field.metadata)
                 for field in schema.fields]

# Also override `names` with the same mechanism
schema.names = [name + "_renamed" for name in schema.names]

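Now df.schema will print all the renamed names. Be aware, though, that this only mutates the schema object that df caches on the Python side: the columns of the underlying DataFrame keep their old names, so a query that refers to a renamed column will still fail. To actually apply the new schema you have to rebuild the DataFrame, for example with spark.createDataFrame(df.rdd, schema). This approach also renames only the top-level fields, not nested ones.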
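If only top-level columns need new names, a sketch that avoids touching the cached schema is to compute the new names in plain Python and pass them to DataFrame.toDF, which returns a new DataFrame with the columns renamed. The helper name with_suffix is hypothetical; the "_renamed" suffix just mirrors the example above.

```python
def with_suffix(names, suffix="_renamed"):
    """Return a new list of column names with `suffix` appended to each one."""
    return [name + suffix for name in names]


# Usage with an existing DataFrame `df` (assumed to exist, as above):
# df_renamed = df.toDF(*with_suffix(df.columns))
print(with_suffix(["array_field", "id"]))  # ['array_field_renamed', 'id_renamed']
```

Unlike assigning to schema.fields, this changes the column names Spark actually resolves against, but it still only affects top-level fields; nested fields need the cast technique described in the next answer.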


Python

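It is not possible to modify a single nested field. You have to recreate the whole structure. In this particular case the simplest solution is to use cast: casting a struct to a struct type with the same number of compatible fields matches the fields by position, so only the names change.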

First a bunch of imports:

from collections import namedtuple
from pyspark.sql.functions import col
from pyspark.sql.types import (
    ArrayType, LongType, StringType, StructField, StructType)

and example data:

Record = namedtuple("Record", ["a", "b", "c"])

df = sc.parallelize([([Record("foo", 1, 3)], )]).toDF(["array_field"])

Let's confirm that the schema is the same as in your case:

df.printSchema()
root
 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)

You can define the new schema, for example, as a string:

str_schema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"

df.select(col("array_field").cast(str_schema)).printSchema()
root
 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a_renamed: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)
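Writing the type string by hand gets error-prone for wide structs. A minimal sketch, assuming the field names and their Spark SQL type names are already known, builds the same string from (name, type) pairs; array_of_struct_ddl is a hypothetical helper, not part of PySpark.

```python
def array_of_struct_ddl(fields):
    """Build an "array<struct<...>>" type string from (name, type) pairs.

    `fields` is an iterable of (field_name, spark_sql_type_name) tuples,
    e.g. ("b", "bigint"). Names containing special characters would need
    backtick quoting, which this sketch does not handle.
    """
    inner = ",".join(f"{name}:{type_name}" for name, type_name in fields)
    return f"array<struct<{inner}>>"


str_schema = array_of_struct_ddl(
    [("a_renamed", "string"), ("b", "bigint"), ("c", "bigint")])
print(str_schema)  # array<struct<a_renamed:string,b:bigint,c:bigint>>
```

The result can be passed to cast exactly like the hand-written str_schema.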

or a DataType:

struct_schema = ArrayType(StructType([
    StructField("a_renamed", StringType()),
    StructField("b", LongType()),
    StructField("c", LongType())
]))

df.select(col("array_field").cast(struct_schema)).printSchema()
root
 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a_renamed: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)

Scala

The same techniques can be used in Scala:

case class Record(a: String, b: Long, c: Long)

import spark.implicits._  // provides toDF and the $"..." syntax (already imported in spark-shell)

val df = Seq(Tuple1(Seq(Record("foo", 1, 3)))).toDF("array_field")

val strSchema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"

df.select($"array_field".cast(strSchema))

or

import org.apache.spark.sql.types._

val structSchema = ArrayType(StructType(Seq(
    StructField("a_renamed", StringType),
    StructField("b", LongType),
    StructField("c", LongType)
)))

df.select($"array_field".cast(structSchema))

Possible improvements:

If you use an expressive data manipulation or JSON processing library, it can be easier to dump the data type to a dict or a JSON string and take it from there, for example (Python / toolz):

from toolz.curried import pipe, update_in, map
from operator import attrgetter

# Update name to "a_updated" if name is "a"
rename_field = update_in(
    keys=["name"], func=lambda x: "a_updated" if x == "a" else x)

updated_schema = pipe(
   #  Get schema of the field as a dict
   df.schema["array_field"].jsonValue(),
   # Update fields with rename
   update_in(
       keys=["type", "elementType", "fields"],
       func=lambda x: pipe(x, map(rename_field), list)),
   # Load schema from dict
   StructField.fromJson,
   # Get data type
   attrgetter("dataType"))

df.select(col("array_field").cast(updated_schema)).printSchema()
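If you would rather not depend on toolz, the same idea can be sketched with plain recursion over the dict returned by jsonValue(). The sketch assumes the JSON layout PySpark uses for data types ("struct" with "fields", "array" with "elementType", "map" with "keyType"/"valueType", and plain strings for primitive types). rename_fields is a hypothetical helper, and unlike the toolz version above it renames matching fields at every nesting level.

```python
def rename_fields(data_type, rename):
    """Return a copy of a Spark data type in JSON (dict) form with every
    struct field name passed through `rename`, at any nesting depth.
    The input is never mutated."""
    if not isinstance(data_type, dict):
        # Primitive types are plain strings such as "string" or "long".
        return data_type
    kind = data_type.get("type")
    if kind == "struct":
        fields = [
            {**field,
             "name": rename(field["name"]),
             "type": rename_fields(field["type"], rename)}
            for field in data_type["fields"]
        ]
        return {**data_type, "fields": fields}
    if kind == "array":
        return {**data_type,
                "elementType": rename_fields(data_type["elementType"], rename)}
    if kind == "map":
        return {**data_type,
                "keyType": rename_fields(data_type["keyType"], rename),
                "valueType": rename_fields(data_type["valueType"], rename)}
    # Anything else (e.g. user-defined types) is returned unchanged.
    return data_type


# Usage with PySpark (assumes `df` and `col` from the example above):
# from pyspark.sql.types import StructType
# new_schema = StructType.fromJson(
#     rename_fields(df.schema.jsonValue(), lambda n: "a_updated" if n == "a" else n))
# df.select([col(c).cast(f.dataType).alias(f.name)
#            for c, f in zip(df.columns, new_schema.fields)])
```

Because the function returns new dicts instead of updating them in place, the original schema JSON stays usable, for example for comparing before and after.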

You can recurse over the data frame's schema to create a new schema with the required changes.

A schema in PySpark is a StructType that holds a list of StructFields, and each StructField holds either a primitive type or a complex type such as another StructType, an ArrayType or a MapType.

This means that we can decide whether to recurse based on the type: structs, arrays and maps can contain nested fields, while any other type is a leaf.

Below is an annotated sample implementation of this idea.

# Some imports
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import ArrayType, DataType, MapType, StructField, StructType

# We take a dataframe and return a new one with the required changes
def cleanDataFrame(df: DataFrame) -> DataFrame:
    # Returns a new sanitized field name (this function can be anything really)
    def sanitizeFieldName(s: str) -> str:
        return s.replace("-", "_").replace("&", "_").replace("\"", "_")\
            .replace("[", "_").replace("]", "_").replace(".", "_")

    # We call this on all fields to build a new StructField with the
    # sanitized name and a recursively cleaned data type.
    def sanitizeField(field: StructField) -> StructField:
        return StructField(sanitizeFieldName(field.name),
                           cleanSchema(field.dataType),
                           field.nullable, field.metadata)

    def cleanSchema(dataType: DataType) -> DataType:
        # Structs, arrays and maps can contain nested fields, so we recurse
        # into them; any other type is a leaf node and is returned as-is.
        # Constructing new objects (rather than mutating shallow copies)
        # keeps StructType's internal `names` list in sync with its fields.
        if isinstance(dataType, StructType):
            return StructType([sanitizeField(f) for f in dataType.fields])
        if isinstance(dataType, ArrayType):
            return ArrayType(cleanSchema(dataType.elementType),
                             dataType.containsNull)
        if isinstance(dataType, MapType):
            return MapType(cleanSchema(dataType.keyType),
                           cleanSchema(dataType.valueType),
                           dataType.valueContainsNull)
        return dataType

    # Now since we have the new schema we can create a new DataFrame
    # by using the old DataFrame's RDD as data and the new schema as the
    # schema for the data
    spark = SparkSession.builder.getOrCreate()
    return spark.createDataFrame(df.rdd, cleanSchema(df.schema))
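The chain of replace calls in sanitizeFieldName can also be written as a single regular-expression substitution. This is a minimal sketch that replaces the same characters (-, &, ", [, ] and .) with an underscore; sanitize_field_name is a hypothetical drop-in replacement, not part of the original answer.

```python
import re

# Character class listing every character the original replace() chain rewrites.
_UNSAFE_CHARS = re.compile(r'[-&"\[\].]')


def sanitize_field_name(s: str) -> str:
    """Replace -, &, ", [, ] and . with underscores."""
    return _UNSAFE_CHARS.sub("_", s)


print(sanitize_field_name('a-b&c"d[e].f'))  # a_b_c_d_e__f
```

Since the character class is the single place that lists the unsafe characters, sanitizing another character is a one-character change.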