Algorithm to print all paths with a given sum in a binary tree

Here's an answer that runs in O(n + numResults) expected time, assuming constant-time hash table operations (essentially the same as @Somebody's answer, but with all issues resolved):

  1. Do a depth-first traversal of the binary tree. Handle each node (steps 3 to 5) before its children, i.e. in pre-order, and clean it up (step 6) only after both of its subtrees are done. An in-order or post-order visit would miss paths that start at a node and continue into a subtree that was traversed before the node was added.
  2. As you do the traversal, maintain the cumulative sum of node values from the root node to the node above the current node. Let's call this value cumulativeSumBeforeNode.
  3. When you visit a node in the traversal, add it to a hash table at key cumulativeSumBeforeNode (the value at that key is a list of nodes).
  4. Compute the cumulative sum including the current node (cumulativeSumBeforeNode plus the node's value) and subtract the target sum. A path from a start node s down to the current node sums to the target exactly when the cumulativeSumBeforeNode recorded for s equals this difference, so look up the difference in the hash table.
  5. If the hash table lookup succeeds, it produces a list of nodes. Each of those nodes is the start node of a solution, and the current node is the end node for each of them. Add each [start node, end node] combination to your list of answers. If the lookup fails, do nothing.
  6. When you've finished visiting a node (including both of its subtrees), remove the node from the list stored at key cumulativeSumBeforeNode in the hash table.
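
Restricted to a single root-to-leaf path, the tree is just an array, and steps 3 to 5 become a plain prefix-sum lookup; the tree version needs step 6 only to undo entries when the traversal backtracks. Here is a minimal sketch of that one-path case (the class and method names are made up for illustration):

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PathPrefixSums {
    // Returns every [start, end] index pair (inclusive) whose values sum to target.
    static List<int[]> segmentsWithSum(int[] values, int target) {
        // Key: cumulative sum before a position; value: every position with that sum.
        Map<Integer, List<Integer>> startsBySum = new HashMap<>();
        List<int[]> result = new ArrayList<>();
        int sumBefore = 0;
        for (int j = 0; j < values.length; j++) {
            // Step 3: register position j under the sum of everything above it.
            startsBySum.computeIfAbsent(sumBefore, k -> new ArrayList<>()).add(j);
            int sumIncluding = sumBefore + values[j];
            // Steps 4-5: every position stored under sumIncluding - target starts
            // a segment that ends at j and sums to target.
            List<Integer> starts =
                    startsBySum.getOrDefault(sumIncluding - target, Collections.emptyList());
            for (int i : starts) {
                result.add(new int[] {i, j});
            }
            sumBefore = sumIncluding;
        }
        return result;
    }

    public static void main(String[] args) {
        // Values on the root-to-91 path of the example tree: 5 -> 16 -> 19 -> 2 -> 91.
        for (int[] seg : segmentsWithSum(new int[] {5, 16, 19, 2, 91}, 112)) {
            System.out.println(seg[0] + ".." + seg[1]); // prints 2..4 (19 -> 2 -> 91)
        }
    }
}
```

Because the node is registered before the lookup, a single value equal to the target is found as a one-element segment.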

Code:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class BinaryTreePathsWithSum {
    public static void main(String[] args) {
        BinaryTreeNode a = new BinaryTreeNode(5);
        BinaryTreeNode b = new BinaryTreeNode(16);
        BinaryTreeNode c = new BinaryTreeNode(16);
        BinaryTreeNode d = new BinaryTreeNode(4);
        BinaryTreeNode e = new BinaryTreeNode(19);
        BinaryTreeNode f = new BinaryTreeNode(2);
        BinaryTreeNode g = new BinaryTreeNode(15);
        BinaryTreeNode h = new BinaryTreeNode(91);
        BinaryTreeNode i = new BinaryTreeNode(8);

        BinaryTreeNode root = a;
        a.left = b;
        a.right = c;
        b.right = e;
        c.right = d;
        e.left = f;
        f.left = g;
        f.right = h;
        h.right = i;

        /*
                5
              /   \
            16     16
              \     \
              19     4
              /
             2
            / \
           15  91
                \
                 8
        */

        List<BinaryTreePath> pathsWithSum = getBinaryTreePathsWithSum(root, 112); // 19 => 2 => 91

        System.out.println(Arrays.toString(pathsWithSum.toArray()));
    }

    public static List<BinaryTreePath> getBinaryTreePathsWithSum(BinaryTreeNode root, int sum) {
        if (root == null) {
            throw new IllegalArgumentException("Must pass non-null binary tree!");
        }

        List<BinaryTreePath> paths = new ArrayList<BinaryTreePath>();
        Map<Integer, List<BinaryTreeNode>> cumulativeSumMap = new HashMap<Integer, List<BinaryTreeNode>>();

        populateBinaryTreePathsWithSum(root, 0, cumulativeSumMap, sum, paths);

        return paths;
    }

    private static void populateBinaryTreePathsWithSum(BinaryTreeNode node, int cumulativeSumBeforeNode, Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int targetSum, List<BinaryTreePath> paths) {
        if (node == null) {
            return;
        }

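        // Step 3: register this node as a possible path start, keyed by the sum of its ancestors.
        addToMap(cumulativeSumMap, cumulativeSumBeforeNode, node);

        // Step 4: a path from start node s down to this node sums to targetSum exactly when
        // the cumulative sum before s equals cumulativeSumIncludingNode - targetSum.
        int cumulativeSumIncludingNode = cumulativeSumBeforeNode + node.value;
        int sumToFind = cumulativeSumIncludingNode - targetSum;

        // Step 5: every node stored under sumToFind starts a path that ends at this node.
        if (cumulativeSumMap.containsKey(sumToFind)) {
            List<BinaryTreeNode> candidatePathStartNodes = cumulativeSumMap.get(sumToFind);

            for (BinaryTreeNode pathStartNode : candidatePathStartNodes) {
                paths.add(new BinaryTreePath(pathStartNode, node));
            }
        }

        populateBinaryTreePathsWithSum(node.left, cumulativeSumIncludingNode, cumulativeSumMap, targetSum, paths);
        populateBinaryTreePathsWithSum(node.right, cumulativeSumIncludingNode, cumulativeSumMap, targetSum, paths);

        // Step 6: this node is not an ancestor of anything visited from here on, so drop it.
        removeFromMap(cumulativeSumMap, cumulativeSumBeforeNode);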
    }

    private static void addToMap(Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int cumulativeSumBeforeNode, BinaryTreeNode node) {
        if (cumulativeSumMap.containsKey(cumulativeSumBeforeNode)) {
            List<BinaryTreeNode> nodes = cumulativeSumMap.get(cumulativeSumBeforeNode);
            nodes.add(node);
        } else {
            List<BinaryTreeNode> nodes = new ArrayList<BinaryTreeNode>();
            nodes.add(node);
            cumulativeSumMap.put(cumulativeSumBeforeNode, nodes);
        }
    }

    private static void removeFromMap(Map<Integer, List<BinaryTreeNode>> cumulativeSumMap, int cumulativeSumBeforeNode) {
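        // Nodes are added and removed in stack (LIFO) order during the traversal,
        // so the node being removed is always the last one stored at this key.
        List<BinaryTreeNode> nodes = cumulativeSumMap.get(cumulativeSumBeforeNode);
        nodes.remove(nodes.size() - 1);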
    }

    private static class BinaryTreeNode {
        public int value;
        public BinaryTreeNode left;
        public BinaryTreeNode right;

        public BinaryTreeNode(int value) {
            this.value = value;
        }

        @Override
        public String toString() {
            return this.value + "";
        }

        @Override
        public int hashCode() {
            return Integer.valueOf(this.value).hashCode();
        }

        @Override
        public boolean equals(Object other) {
            return this == other;
        }
    }

    private static class BinaryTreePath {
        public BinaryTreeNode start;
        public BinaryTreeNode end;

        public BinaryTreePath(BinaryTreeNode start, BinaryTreeNode end) {
            this.start = start;
            this.end = end;
        }

        @Override
        public String toString() {
            return this.start + " to " + this.end;
        }
    }
}
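
The code above reports only the start and end node of each path. If the full path is wanted (the question asks to print the paths), one possible variation, sketched below with its own hypothetical Node type and method names, stores each start's depth instead of the node and slices the answer out of the current root-to-node path. Reporting paths this way costs time proportional to their total length rather than to their number.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class FullPathsWithSum {
    static class Node {
        final int value;
        final Node left, right;

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // Returns the node values of every downward path whose sum is target.
    static List<List<Integer>> pathsWithSum(Node root, int target) {
        List<List<Integer>> result = new ArrayList<>();
        visit(root, 0, target, new ArrayList<>(), new HashMap<>(), result);
        return result;
    }

    private static void visit(Node node, int sumBefore, int target, List<Integer> path,
                              Map<Integer, List<Integer>> depthsBySum, List<List<Integer>> result) {
        if (node == null) {
            return;
        }
        int depth = path.size();
        path.add(node.value); // path now holds the values from the root down to node
        // Store the node's depth (its index in path) instead of the node itself.
        depthsBySum.computeIfAbsent(sumBefore, k -> new ArrayList<>()).add(depth);

        int sumIncluding = sumBefore + node.value;
        List<Integer> starts = depthsBySum.get(sumIncluding - target);
        if (starts != null) {
            for (int start : starts) {
                // The answer path is a contiguous slice of the root-to-node path.
                result.add(new ArrayList<>(path.subList(start, depth + 1)));
            }
        }

        visit(node.left, sumIncluding, target, path, depthsBySum, result);
        visit(node.right, sumIncluding, target, path, depthsBySum, result);

        // Backtrack: undo this node's map entry (step 6) and pop it off the path.
        List<Integer> atKey = depthsBySum.get(sumBefore);
        atKey.remove(atKey.size() - 1);
        path.remove(path.size() - 1);
    }

    // The example tree from the answer above.
    static Node exampleTree() {
        Node f = new Node(2, new Node(15, null, null), new Node(91, null, new Node(8, null, null)));
        Node b = new Node(16, null, new Node(19, f, null));
        Node c = new Node(16, null, new Node(4, null, null));
        return new Node(5, b, c);
    }

    public static void main(String[] args) {
        System.out.println(pathsWithSum(exampleTree(), 112)); // [[19, 2, 91]]
    }
}
```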

Here is an approach with O(n log n) complexity.

  1. Traverse the tree using an inorder traversal.
  2. At the same time, maintain all the nodes along with their cumulative sums in a HashMap<CumulativeSum, reference to the corresponding node>.
  3. At a given node, calculate the cumulative sum from the root to that node; call it SUM.
  4. Look up the value SUM - K in the HashMap, where K is the target sum.
  5. If the entry exists, take the corresponding node reference from the HashMap.
  6. Now we have a valid path from the node reference to the current node (starting just below the stored node, whose own value is already included in SUM - K).
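
Step 2 keeps a single node reference per cumulative sum. That is only safe when cumulative sums never repeat along a root-to-leaf path, which holds if every value is positive; a zero or negative value lets two nodes on the same path share a sum, and a later put overwrites the earlier one. The sketch below (hypothetical names; one root-to-leaf path given as an array; each stored sum includes its node, and a sentinel index -1 stands for the empty prefix so paths can start at the root) counts matching paths both ways:

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class SingleVsListMap {
    // One node index per cumulative sum, as in step 2: a later node with the
    // same sum overwrites the earlier one.
    static int countWithSingleReference(int[] values, int k) {
        Map<Integer, Integer> indexBySum = new HashMap<>();
        indexBySum.put(0, -1); // empty prefix, so a path may start at the root
        int count = 0;
        int sum = 0;
        for (int j = 0; j < values.length; j++) {
            sum += values[j]; // SUM: root down to node j
            if (indexBySum.containsKey(sum - k)) {
                count++; // path runs from the node after the stored one down to j
            }
            indexBySum.put(sum, j);
        }
        return count;
    }

    // Every node index per cumulative sum.
    static int countWithList(int[] values, int k) {
        Map<Integer, List<Integer>> indicesBySum = new HashMap<>();
        indicesBySum.computeIfAbsent(0, s -> new ArrayList<>()).add(-1);
        int count = 0;
        int sum = 0;
        for (int j = 0; j < values.length; j++) {
            sum += values[j];
            List<Integer> stored = indicesBySum.get(sum - k);
            if (stored != null) {
                count += stored.size();
            }
            indicesBySum.computeIfAbsent(sum, s -> new ArrayList<>()).add(j);
        }
        return count;
    }

    public static void main(String[] args) {
        int[] path = {3, 0, 4}; // downward paths summing to 4: 0 -> 4 and 4
        System.out.println(countWithSingleReference(path, 4)); // 1
        System.out.println(countWithList(path, 4));            // 2
    }
}
```

The lookup happens before the current sum is stored, so a target of 0 does not match the empty path ending at the current node.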

Well, this is a tree, not a graph. So, you can do something like this:

Pseudocode:

global ResultList

function ProcessNode(CurrentNode, CurrentSum)
    CurrentSum+=CurrentNode->Value
    if (CurrentSum==SumYouAreLookingFor) AddNodeTo ResultList
    for all Children of CurrentNode
          ProcessNode(Child,CurrentSum)

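Well, this gives you the paths that start at the root. However, you can make one small change and start a fresh search (with a sum of 0) at every node. Start it exactly once per node: calling ProcessNode(Child,0) inside ProcessNode's own loop would restart the search at a child on every call to its parent, including calls that are continuing a path from higher up, so the same path would be reported more than once.

function ProcessAllNodes(CurrentNode)
    ProcessNode(CurrentNode,0)
    for all Children of CurrentNode
          ProcessAllNodes(Child)

You might need to think about it for a second (I'm busy with other things), but this runs the same algorithm rooted at every node in the tree, for O(n·h) work in total on a tree of height h.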

EDIT: this actually gives the "end node" only. However, as this is a tree, you can just start at those end nodes and walk back up until you get the required sum.
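
Walking back up from an end node needs a way to reach its ancestors. The sketch below assumes each node has a parent pointer, which the pseudocode's nodes do not have (an explicit root-to-node stack would work as well); the names are illustrative. It keeps walking to the root rather than stopping at the first match, because with zero or negative values more than one start can pair with the same end node.

```java
import java.util.ArrayList;
import java.util.List;

public class WalkBackUp {
    static class Node {
        final int value;
        final Node parent; // assumed: the pseudocode's nodes have no parent pointer

        Node(int value, Node parent) {
            this.value = value;
            this.parent = parent;
        }
    }

    // Walks from an end node toward the root and returns the value of every
    // start node whose path down to end sums to target (nearest start first).
    static List<Integer> startValuesFor(Node end, int target) {
        List<Integer> starts = new ArrayList<>();
        int sum = 0;
        for (Node n = end; n != null; n = n.parent) {
            sum += n.value;
            if (sum == target) {
                starts.add(n.value);
            }
        }
        return starts;
    }

    // Builds a single chain root -> ... -> last and returns the last (deepest) node.
    static Node chain(int... values) {
        Node n = null;
        for (int v : values) {
            n = new Node(v, n);
        }
        return n;
    }

    public static void main(String[] args) {
        System.out.println(startValuesFor(chain(5, 16, 19, 2, 91), 112)); // [19]
        System.out.println(startValuesFor(chain(3, 0, 4), 4));            // [4, 0]
    }
}
```

If the search reports the same end node once per matching start, deduplicate the end nodes before walking up, or each start will be found more than once.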

EDIT 2: and, of course, if all values are positive, you can abort the descent as soon as your current sum is >= the required one (after checking it for equality).
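
Here is a sketch of the root-anchored ProcessNode with that early exit, collecting end-node values into a list instead of a global ResultList. It assumes all values are strictly positive (with zeros, a path could continue past an exact match and still sum to the target); Node and the method names are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

public class PrunedRootSearch {
    static class Node {
        final int value;
        final Node left, right;

        Node(int value, Node left, Node right) {
            this.value = value;
            this.left = left;
            this.right = right;
        }
    }

    // ProcessNode from the pseudocode (paths starting at the root), collecting
    // end-node values. Valid only if every value is strictly positive.
    static void processNode(Node node, int sumAbove, int target, List<Integer> endValues) {
        if (node == null) {
            return;
        }
        int sum = sumAbove + node.value;
        if (sum == target) {
            endValues.add(node.value);
        }
        if (sum >= target) {
            return; // positive values can only push the sum further up, so stop here
        }
        processNode(node.left, sum, target, endValues);
        processNode(node.right, sum, target, endValues);
    }

    static List<Integer> endValues(Node root, int target) {
        List<Integer> result = new ArrayList<>();
        processNode(root, 0, target, result);
        return result;
    }

    // The example tree from the first answer.
    static Node exampleTree() {
        Node f = new Node(2, new Node(15, null, null), new Node(91, null, new Node(8, null, null)));
        Node b = new Node(16, null, new Node(19, f, null));
        Node c = new Node(16, null, new Node(4, null, null));
        return new Node(5, b, c);
    }

    public static void main(String[] args) {
        System.out.println(endValues(exampleTree(), 21)); // [16, 16] (both 5 -> 16 paths)
    }
}
```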


Based on Christian's answer above:

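// Assumes a Node class with an int val and Node left/right children.
// Prints every downward path that starts at n and sums to sum.
public void printSumsFrom(Node n, int sum, int currentSum, String buffer) {
    if (n == null) {
        return;
    }
    int newSum = currentSum + n.val;
    String newBuffer = buffer + " " + n.val;
    if (newSum == sum) {
        System.out.println(newBuffer);
    }
    printSumsFrom(n.left, sum, newSum, newBuffer);
    printSumsFrom(n.right, sum, newSum, newBuffer);
}

// Starts a fresh search at every node exactly once. Restarting from inside
// printSumsFrom would restart once per call and print the same path repeatedly.
public void printSums(Node n, int sum) {
    if (n == null) {
        return;
    }
    printSumsFrom(n, sum, 0, "");
    printSums(n.left, sum);
    printSums(n.right, sum);
}

printSums(root, targetSum);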
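
For a self-contained run of the same idea, the sketch below adds a minimal Node type (assumed, since the snippet does not define one), starts a fresh search exactly once per node, and collects the paths into a list instead of printing them; the method names are illustrative. Every node is visited once per start node at or above it, so the work is O(n·h) for a tree of height h, plus the cost of building the strings.

```java
import java.util.ArrayList;
import java.util.List;

public class AllPathSums {
    static class Node {
        final int val;
        final Node left, right;

        Node(int val, Node left, Node right) {
            this.val = val;
            this.left = left;
            this.right = right;
        }
    }

    // Collects every downward path that starts at n and sums to target.
    static void fromNode(Node n, int target, int currentSum, String buffer, List<String> out) {
        if (n == null) {
            return;
        }
        int newSum = currentSum + n.val;
        String newBuffer = buffer.isEmpty() ? String.valueOf(n.val) : buffer + " " + n.val;
        if (newSum == target) {
            out.add(newBuffer);
        }
        fromNode(n.left, target, newSum, newBuffer, out);
        fromNode(n.right, target, newSum, newBuffer, out);
    }

    // Visits each node once and starts one fresh search from it.
    static void fromEveryNode(Node n, int target, List<String> out) {
        if (n == null) {
            return;
        }
        fromNode(n, target, 0, "", out);
        fromEveryNode(n.left, target, out);
        fromEveryNode(n.right, target, out);
    }

    static List<String> pathsWithSum(Node root, int target) {
        List<String> out = new ArrayList<>();
        fromEveryNode(root, target, out);
        return out;
    }

    public static void main(String[] args) {
        // Chain 1 -> 2 -> 3: the paths summing to 3 are "1 2" and "3".
        Node root = new Node(1, new Node(2, new Node(3, null, null), null), null);
        System.out.println(pathsWithSum(root, 3)); // [1 2, 3]
    }
}
```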

Tags:

Binary Tree