Spark DataFrame: difference of Timestamp columns over consecutive rows

Thanks to a hint from @lostInOverflow, I came up with the following solution:

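import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
import spark.implicits._ // enables $"colName"; spark-shell imports this automatically

// Order each id's rows by start date so lag() returns the previous interval's end
val w = Window.partitionBy("id").orderBy("initDate")
val previousEnd = lag($"endDate", 1).over(w)
val result = filteredDF.withColumn("prev", previousEnd)
                       // whole days between this row's start and the previous row's end
                       .withColumn("difference", datediff($"initDate", $"prev"))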
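Below is a self-contained sketch of the solution above. It assumes a local `SparkSession` and a small, made-up `filteredDF`. The column names `id`, `initDate` and `endDate` come from the question, but the rows are hypothetical. Because `datediff` only counts calendar days and ignores the time of day, the sketch also adds an exact gap in seconds using `unix_timestamp`:

```scala
import java.sql.Timestamp
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Local session for the sketch; in spark-shell `spark` already exists
val spark = SparkSession.builder().master("local[*]").appName("lag-diff").getOrCreate()
import spark.implicits._

// Hypothetical sample data: two consecutive intervals per id
val filteredDF = Seq(
  (1, Timestamp.valueOf("2017-01-01 08:00:00"), Timestamp.valueOf("2017-01-03 10:00:00")),
  (1, Timestamp.valueOf("2017-01-05 09:00:00"), Timestamp.valueOf("2017-01-06 12:00:00")),
  (2, Timestamp.valueOf("2017-02-01 00:00:00"), Timestamp.valueOf("2017-02-02 00:00:00")),
  (2, Timestamp.valueOf("2017-02-02 06:30:00"), Timestamp.valueOf("2017-02-04 00:00:00"))
).toDF("id", "initDate", "endDate")

val w = Window.partitionBy("id").orderBy("initDate")
val prevEnd = lag($"endDate", 1).over(w)

val result = filteredDF
  .withColumn("prev", prevEnd)
  // calendar days between the two dates (time of day is ignored)
  .withColumn("diffDays", datediff($"initDate", $"prev"))
  // exact gap: unix_timestamp turns each timestamp into epoch seconds
  .withColumn("diffSeconds", unix_timestamp($"initDate") - unix_timestamp($"prev"))

result.orderBy("id", "initDate").show(false)
```

The first row of each `id` has no predecessor, so `prev` and both differences are `null` there. Divide `diffSeconds` by 3600 if the gap is wanted in hours.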

Try:

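import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions._
import spark.implicits._ // for $"colName" outside spark-shell

val w = Window.partitionBy("id").orderBy("endDate")

// datediff(end, start) returns the number of days between two dates;
// date_sub(start, days) would shift a date instead of computing a difference
df.withColumn("difference", datediff($"initDate", lag($"endDate", 1).over(w)))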
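`date_sub` and `datediff` are easy to confuse. `date_sub(start, days)` moves a date back by a given number of days, while `datediff(end, start)` returns the number of days between two dates. The short sketch below makes the contrast concrete. It assumes a local session, and the two literal dates are hypothetical:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder().master("local[*]").appName("date-sub-vs-datediff").getOrCreate()
import spark.implicits._

// Hypothetical single row: a start date and the previous row's end date
val df = Seq(("2017-01-05", "2017-01-03")).toDF("initDate", "prevEnd")
  .select(to_date($"initDate").as("initDate"), to_date($"prevEnd").as("prevEnd"))

val compared = df.select(
  date_sub($"initDate", 2).as("shifted"),          // a DATE two days earlier: 2017-01-03
  datediff($"initDate", $"prevEnd").as("gapDays")  // an INT day count: 2
)
compared.show()
```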