Generating unique, ordered Pythagorean triplets

Substantially faster than any of the solutions so far. Finds triplets via a ternary tree.

Wolfram says:

Hall (1970) and Roberts (1977) prove that (a, b, c) is a primitive Pythagorean triple if and only if

(a,b,c)=(3,4,5)M

where M is a finite product of the matrices U,A,D.

And there we have a formula to generate every primitive triple.

In the tree generated by the above formula, the hypotenuse only ever grows, so it is easy to cut the search off once it exceeds a maximum length.
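
For concreteness, here is one level of the tree expanded by hand. This is just a small illustrative sketch (not part of the generator below), with the matrices written in the same row-vector orientation the generator uses; note that every child has a strictly larger hypotenuse than its parent:

import numpy as np

# U, A, D written so that a row-vector triple (a, b, c) is multiplied on the right.
U = np.array([[ 1,  2,  2], [-2, -1, -2], [2, 2, 3]])
A = np.array([[ 1,  2,  2], [ 2,  1,  2], [2, 2, 3]])
D = np.array([[-1, -2, -2], [ 2,  1,  2], [2, 2, 3]])

root = np.array([3, 4, 5])
for M in (U, A, D):
    print(root @ M)   # [ 5 12 13], then [21 20 29], then [15  8 17]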

In Python:

import numpy as np

def gen_prim_pyth_trips(limit=None):
    # The three matrices that generate the ternary tree of primitive triples,
    # written so that a row-vector triple is multiplied on the right.
    u = np.array([[ 1,  2,  2], [-2, -1, -2], [ 2,  2,  3]])
    a = np.array([[ 1,  2,  2], [ 2,  1,  2], [ 2,  2,  3]])
    d = np.array([[-1, -2, -2], [ 2,  1,  2], [ 2,  2,  3]])
    uad = np.array([u, a, d])
    m = np.array([3, 4, 5])          # the root of the tree
    while m.size:
        m = m.reshape(-1, 3)         # one row per triple in the current generation
        if limit:
            m = m[m[:, 2] <= limit]  # prune triples whose hypotenuse exceeds the limit
        yield from m
        m = np.dot(m, uad)           # each surviving triple spawns three children

If you'd like all triples and not just the primitives:

def gen_all_pyth_trips(limit):
    for prim in gen_prim_pyth_trips(limit):
        i = prim
        for _ in range(limit // prim[2]):  # each primitive has limit // c multiples in range
            yield i
            i = i + prim                   # next multiple of the primitive
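
As a quick sanity check of both generators, a small limit is easy to verify by hand: there are 5 primitive triples and 11 triples in total with hypotenuse at most 30.

prims = list(gen_prim_pyth_trips(30))
print(len(prims))                    # 5
print([t.tolist() for t in prims])   # [[3, 4, 5], [5, 12, 13], [21, 20, 29], [15, 8, 17], [7, 24, 25]]

all_trips = list(gen_all_pyth_trips(30))
print(len(all_trips))                # 11 -- (3, 4, 5) alone contributes 30 // 5 = 6 multiples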

list(gen_prim_pyth_trips(10**4)) took 2.81 milliseconds to come back with 1593 elements, while list(gen_all_pyth_trips(10**4)) took 19.8 milliseconds to come back with 12471 elements.

For reference, the accepted answer (in Python) took 38 seconds for 12471 elements.

Just for fun, setting the upper limit to one million, list(gen_all_pyth_trips(10**6)) returns in 2.66 seconds with 1980642 elements (almost 2 million triples in under 3 seconds). list(gen_all_pyth_trips(10**7)) brings my computer to its knees, as the list grows so large that it consumes every last bit of RAM. Doing something like sum(1 for _ in gen_all_pyth_trips(10**7)) gets around that limitation and returns in 30 seconds with 23471475 elements.

For more information on the algorithm used, check out the articles on Wolfram MathWorld and Wikipedia.


Pythagorean triples make a good example for the claim "for loops considered harmful", because for loops seduce us into thinking about counting, which is often the most irrelevant part of a task.

(I'm going to stick with pseudo-code to avoid language biases, and to keep the pseudo-code streamlined, I'll not optimize away multiple calculations of e.g. x * x and y * y.)

Version 1:

for x in 1..N {
    for y in 1..N {
        for z in 1..N {
            if x * x + y * y == z * z then {
                // use x, y, z
            }
        }
    }
}

is the worst solution. It generates duplicates, and traverses parts of the space that aren't useful (e.g. whenever z < y). Its time complexity is cubic on N.

Version 2, the first improvement, comes from requiring x < y < z to hold, as in:

for x in 1..N {
    for y in x+1..N {
        for z in y+1..N {
            if x * x + y * y == z * z then {
                // use x, y, z
            }
        }
    }
}

which reduces run time and eliminates duplicated solutions. However, it is still cubic on N; the improvement is just a reduction of the coefficient of N-cubed.

It is pointless to continue examining increasing values of z after z * z < x * x + y * y no longer holds. That fact motivates Version 3, the first step away from brute-force iteration over z:

for x in 1..N {
    for y in x+1..N {
        z = y + 1
        while z * z < x * x + y * y {
            z = z + 1
        }
        if z * z == x * x + y * y and z <= N then {
            // use x, y, z
        }
    }
}

For N of 1000, this is about 5 times faster than Version 2, but it is still cubic on N.

The next insight is that x and y are the only independent variables; z depends on their values, and the last z value considered for the previous value of y is a good starting search value for the next value of y. That leads to Version 4:

for x in 1..N {
    y = x+1
    z = y+1
    while z <= N {
        while z * z < x * x + y * y {
            z = z + 1
        }
        if z * z == x * x + y * y and z <= N then {
            // use x, y, z
        }
        y = y + 1
    }
}

which allows y and z to "sweep" the values above x only once. Not only is it over 100 times faster for N of 1000, it is quadratic on N, so the speedup increases as N grows.
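
For anyone who wants to run Version 4 directly, here is a straightforward Python transcription of the pseudocode above (the function name is mine, and collecting the triples into a list is just one way to "use x, y, z"):

def triples_v4(n):
    result = []
    for x in range(1, n + 1):
        y = x + 1
        z = y + 1
        while z <= n:
            # z only ever moves upward; it is never reset for the next y
            while z * z < x * x + y * y:
                z += 1
            if z * z == x * x + y * y and z <= n:
                result.append((x, y, z))
            y += 1
    return result

For N of 10,000 this produces the same 12,471 triples quoted in the timings below, though pure Python is of course much slower than the Java runs reported there.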

I've encountered this kind of improvement often enough to be mistrustful of "counting loops" for any but the most trivial uses (e.g. traversing an array).

Update: Apparently I should have pointed out a few things about V4 that are easy to overlook.

  1. Both of the while loops are controlled by the value of z (one directly, the other indirectly through the square of z). The inner while is actually speeding up the outer while, rather than being orthogonal to it. It's important to look at what the loops are doing, not merely to count how many loops there are.

  2. All of the calculations in V4 are strictly integer arithmetic. Conversion to/from floating-point, as well as floating-point calculations, are costly by comparison.

  3. V4 runs in constant memory, requiring only three integer variables. There are no arrays or hash tables to allocate and initialize (and, potentially, to cause an out-of-memory error).

  4. The original question allowed all of x, y, and z to vary over the same range. V1..V4 followed that pattern.

Below is a not-very-scientific set of timings (using Java under Eclipse on my older laptop with other stuff running...), where the "use x, y, z" was implemented by instantiating a Triple object with the three values and putting it in an ArrayList. (For these runs, N was set to 10,000, which produced 12,471 triples in each case.)

Version 4:           46 sec.
using square root:  134 sec.
array and map:      400 sec.

The "array and map" algorithm is essentially:

squares = array of i*i for i in 1 .. N
roots = map of i*i -> i for i in 1 .. N
for x in 1 .. N
    for y in x+1 .. N
        z = roots[squares[x] + squares[y]]
        if z exists use x, y, z
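
In Python, the "array and map" idea amounts to a dictionary lookup of perfect squares. A rough sketch (the function name is mine; this mirrors the pseudocode, not the timed Java code):

def triples_by_lookup(n):
    squares = [i * i for i in range(n + 1)]       # squares[i] == i*i
    roots = {i * i: i for i in range(1, n + 1)}   # perfect square -> its root (root <= n)
    result = []
    for x in range(1, n + 1):
        for y in range(x + 1, n + 1):
            z = roots.get(squares[x] + squares[y])   # None unless x*x + y*y is the square of some z <= n
            if z is not None:
                result.append((x, y, z))
    return result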

The "using square root" algorithm is essentially:

for x in 1 .. N
    for y in x+1 .. N
        z = (int) sqrt(x * x + y * y)
        if z * z == x * x + y * y then use x, y, z
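
And the "using square root" algorithm, again as a rough Python sketch (name mine). The float sqrt here is exactly the floating-point cost that point 2 above refers to; the z <= n check keeps z in the same range as V1..V4:

from math import sqrt

def triples_by_sqrt(n):
    result = []
    for x in range(1, n + 1):
        for y in range(x + 1, n + 1):
            z = int(sqrt(x * x + y * y))              # float guess for z
            if z * z == x * x + y * y and z <= n:     # confirm with exact integer arithmetic
                result.append((x, y, z))
    return result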

The actual code for V4 is:

public Collection<Triple> byBetterWhileLoop() {
    Collection<Triple> result = new ArrayList<Triple>(limit);
    for (int x = 1; x < limit; ++x) {
        int xx = x * x;
        int y = x + 1;
        int z = y + 1;
        while (z <= limit) {
            int zz = xx + y * y;            // note: zz holds x*x + y*y, the value z*z must reach
            while (z * z < zz) {++z;}       // advance z; it never needs to move backwards
            if (z * z == zz && z <= limit) {
                result.add(new Triple(x, y, z));
            }
            ++y;
        }
    }
    return result;
}

Note that x * x is calculated in the outer loop (although I didn't bother to cache z * z); similar optimizations are done in the other variations.

I'll be glad to provide the Java source code on request for the other variations I timed, in case I've mis-implemented anything.

Tags: Python, Math