Adding a group count column to a PySpark dataframe

When you call groupBy(), you get back a GroupedData object rather than a DataFrame, so you have to specify an aggregation before you can display the results. For example:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
]
# sqlCtx is an existing SQLContext; on Spark 2.0+ a SparkSession (spark) works the same way here
df = sqlCtx.createDataFrame(data, ["x", "y"])
df.groupBy('x').count().select('x', f.col('count').alias('n')).show()
#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+
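To make the semantics concrete, here is a plain-Python sketch (no Spark involved, just the standard library) of what groupBy('x').count() computes on the same toy data: the rows are collapsed to one (x, n) pair per distinct value of x. It only illustrates the result shape; it is not how Spark executes the query.

```python
from collections import Counter

# Same toy data as the PySpark example: (x, y) pairs.
data = [('a', 5), ('a', 8), ('a', 7), ('b', 1)]

# groupBy('x').count() collapses the rows: one (x, n) pair per distinct x.
counts = Counter(x for x, _ in data)

# Spark does not guarantee row order, so sort for a stable display.
grouped = sorted(counts.items())
print(grouped)  # [('a', 3), ('b', 1)]
```

Note that the y values are gone from the result, which is exactly the limitation the window approach addresses.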

Here I used alias() to rename the column. However, this returns only one row per group. If you want to keep every row and append its group's count, you can do this with a Window:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.select('x', 'y', f.count('x').over(w).alias('n')).sort('x', 'y').show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+
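As a rough mental model, the window version can be sketched in plain Python (standard library only, not Spark): split the rows into partitions keyed by x, then attach each partition's size to every row in it instead of collapsing the rows. This sketch assumes x is never null; Spark's count('x') skips nulls, whereas len() here counts every row.

```python
from collections import defaultdict

data = [('a', 5), ('a', 8), ('a', 7), ('b', 1)]

# Step 1: group rows into partitions keyed by x (like Window.partitionBy('x')).
partitions = defaultdict(list)
for x, y in data:
    partitions[x].append((x, y))

# Step 2: count(...).over(w) evaluates the aggregate over each row's partition
# and attaches the result to every row in it, so no rows are dropped.
with_n = [(x, y, len(partitions[x])) for x, y in data]

# Step 3: sort by (x, y) to match sort('x', 'y') in the PySpark example.
with_n.sort()
print(with_n)  # [('a', 5, 3), ('a', 7, 3), ('a', 8, 3), ('b', 1, 1)]
```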

Or, if you're more comfortable with SQL, you can register the dataframe as a temporary table and use Spark SQL to do the same thing:

df.createOrReplaceTempView('table')  # Spark 2.0+; on Spark 1.x use df.registerTempTable('table')
sqlCtx.sql(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM table ORDER BY x, y'
).show()
#+---+---+---+
#|  x|  y|  n|
#+---+---+---+
#|  a|  5|  3|
#|  a|  7|  3|
#|  a|  8|  3|
#|  b|  1|  1|
#+---+---+---+
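The window clause here is standard SQL, so the same query also runs outside Spark. As a quick way to experiment without a cluster, here is a sketch using Python's built-in sqlite3 module, assuming the bundled SQLite is 3.25 or newer (the first release with window functions). The table is named t instead of table because TABLE is a reserved word in SQLite.

```python
import sqlite3

# In-memory SQLite database; window functions need SQLite >= 3.25.
conn = sqlite3.connect(':memory:')
conn.execute('CREATE TABLE t (x TEXT, y INTEGER)')
conn.executemany('INSERT INTO t VALUES (?, ?)',
                 [('a', 5), ('a', 8), ('a', 7), ('b', 1)])

# Same window query as the Spark SQL version, against table t.
rows = conn.execute(
    'SELECT x, y, COUNT(x) OVER (PARTITION BY x) AS n FROM t ORDER BY x, y'
).fetchall()
print(rows)  # [('a', 5, 3), ('a', 7, 3), ('a', 8, 3), ('b', 1, 1)]
conn.close()
```

This only checks the SQL semantics; Spark SQL's dialect differs from SQLite's in other areas.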

I found we can get even closer to the tidyverse example by using withColumn(), which appends the count as a new column while keeping the existing ones:

from pyspark.sql import Window
w = Window.partitionBy('x')
df.withColumn('n', f.count('x').over(w)).sort('x', 'y').show()

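As an addendum to @pault's answer, you can also write the grouped count with agg(), which lets you name the count column directly: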

import pyspark.sql.functions as F

...

(df
    .groupBy(F.col('x'))
    .agg(F.count('x').alias('n'))
    .show())

#+---+---+
#|  x|  n|
#+---+---+
#|  b|  1|
#|  a|  3|
#+---+---+
