Efficient string suffix detection

Let's extend the domains for slightly better coverage:

domains = spark.createDataFrame([
    "something.google.com",  # OK
    "something.google.com.somethingelse.ac.uk", # NOT OK 
    "something.good.com.cy", # OK 
    "something.good.com.cy.mal.org",  # NOT OK
    "something.bad.com.cy",  # NOT OK
    "omgalsogood.com.cy", # NOT OK
    "good.com.cy",   # OK 
    "sogood.example.com",  # OK Match for shorter redundant, mismatch on longer
    "notsoreal.googleecom" # NOT OK
], "string").toDF('domains')

good_domains =  spark.createDataFrame([
    "google.com", "good.com.cy", "alsogood.com.cy",
    "good.example.com", "example.com"  # Redundant case
], "string").toDF('gooddomains')

Now... A naive solution, using only Spark SQL primitives, is to simplify your current approach a bit. Since you've stated that it is safe to assume that these are valid public domains, we can define a function like this:

from pyspark.sql.functions import col, regexp_extract

def suffix(c): 
    return regexp_extract(c, "([^.]+\\.[^.]+$)", 1) 

which extracts the top-level domain and the first-level subdomain:

domains_with_suffix = (domains
    .withColumn("suffix", suffix("domains"))
    .alias("domains"))
good_domains_with_suffix = (good_domains
    .withColumn("suffix", suffix("gooddomains"))
    .alias("good_domains"))

domains_with_suffix.show()
+--------------------+--------------------+
|             domains|              suffix|
+--------------------+--------------------+
|something.google.com|          google.com|
|something.google....|               ac.uk|
|something.good.co...|              com.cy|
|something.good.co...|             mal.org|
|something.bad.com.cy|              com.cy|
|  omgalsogood.com.cy|              com.cy|
|         good.com.cy|              com.cy|
|  sogood.example.com|         example.com|
|notsoreal.googleecom|notsoreal.googleecom|
+--------------------+--------------------+
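
To see what the pattern does outside Spark, here is a minimal plain-Python sketch using the same regular expression. The helper name suffix_py is made up for illustration; it mimics regexp_extract, which returns an empty string when the pattern does not match.

```python
import re

# Same pattern as in `suffix` above: the last two dot-separated labels.
SUFFIX_RE = re.compile(r"([^.]+\.[^.]+$)")

def suffix_py(domain):
    """Hypothetical plain-Python stand-in for regexp_extract(c, pattern, 1)."""
    m = SUFFIX_RE.search(domain)
    # Like regexp_extract, fall back to an empty string when nothing matches.
    return m.group(1) if m else ""

print(suffix_py("something.google.com"))  # google.com
print(suffix_py("good.com.cy"))           # com.cy
print(suffix_py("notsoreal.googleecom"))  # notsoreal.googleecom
print(repr(suffix_py("localhost")))       # '' (no dot, so no match)
```

Note that the suffix is only a join key that narrows down the candidates; it does not decide on its own whether a domain is good.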

Now we can outer join:

from pyspark.sql.functions import (
    col, concat, lit, monotonically_increasing_id, sum as sum_
)

candidates = (domains_with_suffix
    .join(
        good_domains_with_suffix,
        col("domains.suffix") == col("good_domains.suffix"), 
        "left"))

and filter the result:

is_good_expr = (
    col("good_domains.suffix").isNotNull() &      # Match on suffix
    (

        # Exact match
        (col("domains") == col("gooddomains")) |
        # Subdomain match
        col("domains").endswith(concat(lit("."), col("gooddomains")))
    )
)

not_good_domains = (candidates
    .groupBy("domains")  # .groupBy("suffix", "domains") - see the discussion
    .agg((sum_(is_good_expr.cast("integer")) > 0).alias("any_good"))
    .filter(~col("any_good"))
    .drop("any_good"))

not_good_domains.show(truncate=False)     
+----------------------------------------+
|domains                                 |
+----------------------------------------+
|omgalsogood.com.cy                      |
|notsoreal.googleecom                    |
|something.good.com.cy.mal.org           |
|something.google.com.somethingelse.ac.uk|
|something.bad.com.cy                    |
+----------------------------------------+
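
The predicate behind is_good_expr can be checked on a single machine with a few lines of plain Python. This is only a sketch of the matching rule (exact match, or a match preceded by a dot); the function name is_good is made up for illustration and nothing here touches Spark.

```python
def is_good(domain, good_domains):
    """True if `domain` equals a good domain or is a subdomain of one."""
    return any(
        # Exact match, or the good domain preceded by a label separator.
        domain == good or domain.endswith("." + good)
        for good in good_domains
    )

good = ["google.com", "good.com.cy", "alsogood.com.cy",
        "good.example.com", "example.com"]

print(is_good("something.good.com.cy", good))  # True
print(is_good("sogood.example.com", good))     # True (via example.com)
print(is_good("omgalsogood.com.cy", good))     # False (no dot boundary)
print(is_good("notsoreal.googleecom", good))   # False
```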

This is better than the Cartesian product required for a direct join with LIKE, but it is still a rather unsatisfactory brute-force approach, and in the worst case it requires two shuffles: one for the join (this can be skipped if good_domains is small enough to be broadcast), and another one for the groupBy + agg.

Unfortunately, Spark SQL doesn't allow a custom partitioner that would let both operations share a single shuffle (this is, however, possible with a composite key in the RDD API), and the optimizer is not yet smart enough to optimize join(_, "key1") followed by .groupBy("key1", _).

If you can accept some false negatives, you can go probabilistic. First, let's build a probabilistic counter (here using bounter, with a little help from toolz):

from pyspark.sql.functions import concat_ws, reverse, split
from bounter import bounter
from toolz.curried import identity, partition_all

# This is only for testing on toy examples, in practice use more realistic value
size_mb = 20      
chunk_size = 100

def reverse_domain(c):
    return concat_ws(".", reverse(split(c, "\\.")))

def merge(acc, xs):
    acc.update(xs)
    return acc

counter = sc.broadcast((good_domains
    .select(reverse_domain("gooddomains"))
    .rdd.flatMap(identity)
    # Chunk data into groups so we reduce the number of update calls
    .mapPartitions(partition_all(chunk_size))
    # Use tree aggregate to reduce pressure on the driver, 
    # when number of partitions is large*
    # You can use depth parameter for further tuning
    .treeAggregate(bounter(need_iteration=False, size_mb=size_mb), merge, merge)))

Next, define a user-defined function like this:

from pyspark.sql.functions import pandas_udf, PandasUDFType
from toolz import accumulate

def is_good_counter(counter):
    def is_good_(x):
        return any(
            x in counter.value 
            for x in accumulate(lambda x, y: "{}.{}".format(x, y), x.split("."))
        )

    @pandas_udf("boolean", PandasUDFType.SCALAR)
    def _(xs):
        return xs.apply(is_good_)
    return _

and filter the domains:

domains.filter(
    ~is_good_counter(counter)(reverse_domain("domains"))
).show(truncate=False)
+----------------------------------------+
|domains                                 |
+----------------------------------------+
|something.google.com.somethingelse.ac.uk|
|something.good.com.cy.mal.org           |
|something.bad.com.cy                    |
|omgalsogood.com.cy                      |
|notsoreal.googleecom                    |
+----------------------------------------+
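
The core idea here is that a suffix check on a domain becomes a prefix check once the labels are reversed. A minimal plain-Python sketch of that step follows; it uses an exact set in place of the probabilistic counter, and the helper names (reverse_domain_py, label_prefixes, is_good_py) are made up for illustration.

```python
from itertools import accumulate

def reverse_domain_py(domain):
    """'a.b.com' -> 'com.b.a' (plain-Python analogue of reverse_domain)."""
    return ".".join(reversed(domain.split(".")))

def label_prefixes(reversed_domain):
    """All label-aligned prefixes: 'com.b.a' -> ['com', 'com.b', 'com.b.a']."""
    return list(accumulate(
        reversed_domain.split("."),
        lambda acc, label: acc + "." + label,
    ))

# An exact set stands in for the broadcast counter, so this toy version
# has no approximation error.
good_reversed = {
    reverse_domain_py(d)
    for d in ["google.com", "good.com.cy", "alsogood.com.cy",
              "good.example.com", "example.com"]
}

def is_good_py(domain):
    return any(p in good_reversed
               for p in label_prefixes(reverse_domain_py(domain)))

print(label_prefixes("com.google.something"))
# ['com', 'com.google', 'com.google.something']
print(is_good_py("something.google.com"))  # True
print(is_good_py("omgalsogood.com.cy"))    # False
```

Because the prefixes are built label by label, a string such as omgalsogood.com.cy never produces the candidate cy.com.alsogood, so the dot-boundary problem does not arise.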

In Scala this could be done with bloomFilter

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions._
import org.apache.spark.util.sketch.BloomFilter

def reverseDomain(c: Column) = concat_ws(".", reverse(split(c, "\\.")))

val checker = good_domains.stat.bloomFilter(
  // Adjust values depending on the data
  reverseDomain($"gooddomains"), 1000, 0.001 
)

def isGood(checker: BloomFilter) = udf((s: String) => 
  s.split('.').toStream.scanLeft("") {
    case ("", x) => x
    case (acc, x) => s"${acc}.${x}"
  }.tail.exists(checker mightContain _))


domains.filter(!isGood(checker)(reverseDomain($"domains"))).show(false)
+----------------------------------------+
|domains                                 |
+----------------------------------------+
|something.google.com.somethingelse.ac.uk|
|something.good.com.cy.mal.org           |
|something.bad.com.cy                    |
|omgalsogood.com.cy                      |
|notsoreal.googleecom                    |
+----------------------------------------+

and, if needed, it shouldn't be hard to call such code from Python.
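
To illustrate why these sketches err in only one direction, here is a toy Bloom filter in plain Python. It is only a sketch: the class TinyBloom, its parameters and its hashing scheme are made up for illustration and are unrelated to how Spark's BloomFilter or bounter are implemented.

```python
import hashlib

class TinyBloom:
    """Toy Bloom filter (hypothetical, for illustration only)."""

    def __init__(self, num_bits=1024, num_hashes=3):
        self.num_bits = num_bits
        self.num_hashes = num_hashes
        self.bits = 0  # a Python int used as a bit set

    def _positions(self, item):
        # Derive `num_hashes` bit positions from salted SHA-256 digests.
        for i in range(self.num_hashes):
            digest = hashlib.sha256(f"{i}:{item}".encode("utf-8")).digest()
            yield int.from_bytes(digest[:8], "big") % self.num_bits

    def add(self, item):
        for pos in self._positions(item):
            self.bits |= 1 << pos

    def might_contain(self, item):
        return all((self.bits >> pos) & 1 for pos in self._positions(item))

stored = ["com.google", "cy.com.good"]
bloom = TinyBloom()
for d in stored:
    bloom.add(d)

# Stored items are always reported as present.
print(all(bloom.might_contain(d) for d in stored))  # True

# A deliberately saturated one-bit filter reports everything as present.
saturated = TinyBloom(num_bits=1, num_hashes=1)
saturated.add("com.google")
print(saturated.might_contain("org.mal"))  # True, although never added
```

A membership false positive makes a bad domain look good, so it is silently dropped from the result; that is the false negative mentioned above. A stored good domain is never reported as missing.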

This might still not be fully satisfying, due to its approximate nature. If you require an exact result, you can try to leverage the redundant nature of the data, for example with a trie (here using the datrie implementation).

If good_domains are relatively small you can create a single model, in a similar way as in the probabilistic variant:

import string
import datrie


def seq_op(acc, x):
    acc[x] = True
    return acc

def comb_op(acc1, acc2):
    acc1.update(acc2)
    return acc1

trie = sc.broadcast((good_domains
    .select(reverse_domain("gooddomains"))
    .rdd.flatMap(identity)
    # string.printable is a bit excessive if you need standard domain
    # and not enough if you allow internationalized domain names.
    # In the latter case you'll have to adjust the `alphabet`
    # or use different implementation of trie.
    .treeAggregate(datrie.Trie(string.printable), seq_op, comb_op)))

Define a user-defined function:

def is_good_trie(trie):
    def is_good_(x):
        if not x:
            return False
        else:
            return any(
                x == match or x[len(match)] == "."
                for match in trie.value.iter_prefixes(x)
            )

    @pandas_udf("boolean", PandasUDFType.SCALAR)
    def _(xs):
        return xs.apply(is_good_)

    return _

and apply it to the data:

domains.filter(
    ~is_good_trie(trie)(reverse_domain("domains"))
).show(truncate=False)
+----------------------------------------+
|domains                                 |
+----------------------------------------+
|something.google.com.somethingelse.ac.uk|
|something.good.com.cy.mal.org           |
|something.bad.com.cy                    |
|omgalsogood.com.cy                      |
|notsoreal.googleecom                    |
+----------------------------------------+
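
The lookup logic can be sketched without datrie using a nested-dict trie. This is an illustrative stand-in, not datrie's implementation: build_trie, iter_prefixes and is_good_trie_py are hypothetical helpers, with iter_prefixes imitating what trie.value.iter_prefixes(x) yields above.

```python
END = ""  # terminal marker; cannot clash with one-character keys

def build_trie(keys):
    """Build a nested-dict trie over the characters of each key."""
    root = {}
    for key in keys:
        node = root
        for ch in key:
            node = node.setdefault(ch, {})
        node[END] = True
    return root

def iter_prefixes(trie, text):
    """Yield every stored key that is a prefix of `text`, shortest first."""
    node = trie
    for i, ch in enumerate(text):
        node = node.get(ch)
        if node is None:
            return
        if END in node:
            yield text[:i + 1]

def is_good_trie_py(trie, x):
    # Same check as is_good_ above: the stored key must be the whole string
    # or must end exactly at a label boundary.
    return any(x == match or x[len(match)] == "."
               for match in iter_prefixes(trie, x))

# Keys are reversed good domains, as produced by reverse_domain.
trie = build_trie(["com.google", "cy.com.good", "cy.com.alsogood"])

print(is_good_trie_py(trie, "cy.com.good.something"))  # True
print(is_good_trie_py(trie, "com.google"))             # True
print(is_good_trie_py(trie, "com.googleevil"))         # False
print(is_good_trie_py(trie, "cy.com.omgalsogood"))     # False
```

The com.googleevil case shows why the boundary check is needed: com.google is a character-level prefix of it, but not a label-level one.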

This specific approach works under the assumption that all good_domains can be compressed into a single trie, but it can easily be extended to handle cases where this assumption is not satisfied. For example, you can build a single trie per top-level domain or per suffix (as defined in the naive solution)

(good_domains
    .select(suffix("gooddomains"), reverse_domain("gooddomains"))
    .rdd
    .aggregateByKey(datrie.Trie(string.printable), seq_op, comb_op))

and then, either load models on demand from serialized version, or use RDD operations.

The two non-native methods can be further adjusted depending on the data, business requirements (like false negative tolerance in the case of the approximate solution) and available resources (driver memory, executor memory, cardinality of suffixes, access to a distributed POSIX-compliant file system, and so on). There are also some trade-offs to consider when choosing between applying these on DataFrames and RDDs (memory usage, communication and serialization overhead).


* See Understanding treeReduce() in Spark


If I understand correctly, you just want a left anti join using a simple SQL string matching pattern.

from pyspark.sql.functions import expr

dd.alias("l")\
    .join(
        dd1.alias("r"), 
        on=expr("l.domains LIKE concat('%', r.gooddomains)"), 
        how="leftanti"
    )\
    .select("l.*")\
    .show(truncate=False)
#+----------------------------------------+
#|domains                                 |
#+----------------------------------------+
#|something.google.com.somethingelse.ac.uk|
#|something.good.com.cy.mal.org           |
#+----------------------------------------+

The expression concat('%', r.gooddomains) prepends a wildcard to r.gooddomains.

Next, we use l.domains LIKE concat('%', r.gooddomains) to find the rows which match this pattern.

Finally, specify how="leftanti" in order to keep only the rows that don't match.


Update: As pointed out in the comments by @user10938362 there are 2 flaws with this approach:

1) Since this only looks at matching suffixes, there are edge cases where this produces the wrong results. For example:

example.com should match example.com and subdomain.example.com, but not fakeexample.com

There are two ways to approach this. The first is to modify the LIKE expression to handle this. Since we know these are all valid domains, we can check for an exact match or a dot followed by the domain:

like_expr = " OR ".join(
    [
        "(l.domains = r.gooddomains)",
        "(l.domains LIKE concat('%.', r.gooddomains))"
    ]
)

dd.alias("l")\
    .join(
        dd1.alias("r"), 
        on=expr(like_expr), 
        how="leftanti"
    )\
    .select("l.*")\
    .show(truncate=False)

Similarly, one can use RLIKE with a regular expression pattern with a look-behind.
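
As a rough illustration of the look-behind idea, here is a plain-Python sketch using the re module. The helper boundary_pattern is made up for illustration. RLIKE in Spark uses Java regular expressions, which accept the same look-behind syntax, but building the pattern per row inside Spark SQL (including escaping the dots in the good domain) is not shown here and is left as an assumption.

```python
import re

def boundary_pattern(good_domain):
    """Match `good_domain` at the end of a string, but only at the very
    start of the string or right after a dot."""
    return re.compile(r"(?:^|(?<=\.))" + re.escape(good_domain) + r"$")

pattern = boundary_pattern("example.com")

print(bool(pattern.search("example.com")))            # True
print(bool(pattern.search("subdomain.example.com")))  # True
print(bool(pattern.search("fakeexample.com")))        # False
```

re.escape matters here: without it, the dot in example.com would match any character.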

2) The larger issue is that, as explained here, joining on a LIKE expression will cause a Cartesian Product. If dd1 is small enough to be broadcast, then this isn't an issue.

Otherwise, you may run into performance issues and will have to try a different approach.


More on the Spark SQL LIKE operator, from the Apache Hive docs:

A LIKE B:

TRUE if string A matches the SQL simple regular expression B, otherwise FALSE. The comparison is done character by character. The _ character in B matches any character in A (similar to . in posix regular expressions), and the % character in B matches an arbitrary number of characters in A (similar to .* in posix regular expressions). For example, 'foobar' LIKE 'foo' evaluates to FALSE where as 'foobar' LIKE 'foo___' evaluates to TRUE and so does 'foobar' LIKE 'foo%'. To escape % use \ (\% matches one % character). If the data contains a semicolon, and you want to search for it, it needs to be escaped, columnValue LIKE 'a\;b'


Note: This exploits the "trick" of using pyspark.sql.functions.expr to pass in a column value as a parameter to a function.