GroupBy column and filter rows with maximum value in PySpark

Another possible approach is to join the dataframe with itself, specifying "leftsemi". This kind of join keeps only the rows of the left dataframe that have a match on the right side, and returns all columns from the left side and no columns from the right side.

For example:

import pyspark.sql.functions as f
data = [
    ('a', 5, 'c'),
    ('a', 8, 'd'),
    ('a', 7, 'e'),
    ('b', 1, 'f'),
    ('b', 3, 'g')
]
df = sqlContext.createDataFrame(data, ["A", "B", "C"])
df.show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  a|  5|  c|
|  a|  8|  d|
|  a|  7|  e|
|  b|  1|  f|
|  b|  3|  g|
+---+---+---+

The maximum value of column B for each value of column A can be selected with:

df.groupBy('A').agg(f.max('B')).show()
+---+------+
|  A|max(B)|
+---+------+
|  a|     8|
|  b|     3|
+---+------+

Using this expression as the right side of a left semi join, renaming the aggregated column max(B) back to its original name B, and joining on both A and B, we obtain the required result. Joining on B alone would also keep rows from other groups whose B happens to equal some group's maximum.

df.join(df.groupBy('A').agg(f.max('B').alias('B')), on=['A', 'B'], how='leftsemi').show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  b|  3|  g|
|  a|  8|  d|
+---+---+---+
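
To make the semantics of this join concrete, here is a minimal plain-Python model of it. It is not Spark code and says nothing about how Spark executes the join; the helper names max_b_per_group and left_semi_join are made up for this illustration. It only shows which rows survive when the match key is the pair (A, B).

```python
# Plain-Python model of the left semi join described above (not Spark code).
rows = [
    ('a', 5, 'c'),
    ('a', 8, 'd'),
    ('a', 7, 'e'),
    ('b', 1, 'f'),
    ('b', 3, 'g'),
]

def max_b_per_group(rows):
    """Right side of the join: the largest B seen for each value of A."""
    best = {}
    for a, b, _ in rows:
        if a not in best or b > best[a]:
            best[a] = b
    return best

def left_semi_join(rows, best):
    """Keep each left row whose (A, B) pair appears on the right side.

    Only left-side columns are returned, and no left row is duplicated.
    """
    keys = set(best.items())
    return [row for row in rows if (row[0], row[1]) in keys]

result = left_semi_join(rows, max_b_per_group(rows))
print(result)  # [('a', 8, 'd'), ('b', 3, 'g')]

# A row ('a', 3, 'x') has B equal to the maximum of group 'b', but it is
# not the maximum of its own group, so the (A, B) key rejects it.
extra = rows + [('a', 3, 'x')]
print(left_semi_join(extra, max_b_per_group(extra)))  # same two rows as above
```

If several rows in a group share the maximum B, all of them match and all of them are kept.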

The physical plan for this solution differs from that of the accepted answer, and it is still not clear to me which one performs better on large dataframes.

The same result can be obtained with Spark SQL syntax:

df.registerTempTable('table')
q = '''SELECT *
FROM table a LEFT SEMI
JOIN (
    SELECT 
        A,
        max(B) as max_B
    FROM table
    GROUP BY A
    ) t
ON a.A=t.A AND a.B=t.max_B
'''
sqlContext.sql(q).show()
+---+---+---+
|  A|  B|  C|
+---+---+---+
|  b|  3|  g|
|  a|  8|  d|
+---+---+---+

You can do this without a udf using a Window.

Consider the following example:

import pyspark.sql.functions as f
data = [
    ('a', 5),
    ('a', 8),
    ('a', 7),
    ('b', 1),
    ('b', 3)
]
df = sqlCtx.createDataFrame(data, ["A", "B"])
df.show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  5|
#|  a|  8|
#|  a|  7|
#|  b|  1|
#|  b|  3|
#+---+---+

Create a Window partitioned by column A and use it to compute the maximum of each group. Then keep only the rows where the value in column B equals that maximum.

from pyspark.sql import Window
w = Window.partitionBy('A')
df.withColumn('maxB', f.max('B').over(w))\
    .where(f.col('B') == f.col('maxB'))\
    .drop('maxB')\
    .show()
#+---+---+
#|  A|  B|
#+---+---+
#|  a|  8|
#|  b|  3|
#+---+---+
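
One thing to keep in mind is how ties behave: comparing B with the per-partition maximum keeps every row that is tied for the maximum. The sketch below models that in plain Python (it is not Spark code, and the function names are invented for this illustration), next to a variant that keeps a single row per group, which is roughly what ranking each partition by B descending and keeping the first row would give.

```python
from collections import defaultdict

# Same shape as the example data, but with a tie for the maximum in group 'a'.
rows = [('a', 5), ('a', 8), ('a', 8), ('b', 1), ('b', 3)]

def window_max_filter(rows):
    """Model of MAX(B) OVER (PARTITION BY A) followed by WHERE B = maxB."""
    partitions = defaultdict(list)
    for a, b in rows:
        partitions[a].append(b)
    max_b = {a: max(values) for a, values in partitions.items()}
    # Rows are not collapsed: every row tied for the maximum survives.
    return [(a, b) for a, b in rows if b == max_b[a]]

def one_row_per_group(rows):
    """Model of keeping only the first row of each partition ordered by B descending."""
    seen = set()
    kept = []
    for a, b in sorted(rows, key=lambda r: (r[0], -r[1])):
        if a not in seen:
            seen.add(a)
            kept.append((a, b))
    return kept

print(window_max_filter(rows))  # [('a', 8), ('a', 8), ('b', 3)]
print(one_row_per_group(rows))  # [('a', 8), ('b', 3)]
```

Which behaviour is right depends on whether tied rows should all be returned; the question does not say, so this is left as an assumption.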

Or equivalently using Spark SQL:

df.registerTempTable('table')
q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB"
sqlCtx.sql(q).show()
#+---+---+
#|  A|  B|
#+---+---+
#|  b|  3|
#|  a|  8|
#+---+---+