Problem Definition

Given the root of a binary tree, a node target in that tree, and an integer k, find all nodes that are exactly k edges away from target.


Node Class

This is the Node class we'll use to represent each element in the binary tree.

class Node:
    def __init__(self, val: str, left: 'Node' = None, right: 'Node' = None):
        self.val = val
        self.left = left
        self.right = right
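The solution uses nodes as dictionary keys and set members. Since Node defines neither __eq__ nor __hash__, Python falls back to identity-based equality and hashing, so two nodes that share the same val are still distinct keys. A quick check (the class is repeated so the snippet runs on its own; the names x, y, and parent are only for illustration):

```python
class Node:
    def __init__(self, val: str, left: "Node" = None, right: "Node" = None):
        self.val = val
        self.left = left
        self.right = right

# Two distinct nodes that happen to share a value.
x, y = Node("x"), Node("x")

# Default object hashing is identity-based, so both become separate keys.
parent = {x: "first", y: "second"}
print(len(parent))          # 2
print(x == y)               # False: without __eq__, equality is identity
print(x in {x}, y in {x})   # True False
```

This is why the tree can safely contain several nodes with equal values without confusing the parent map.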

Examples

Some test cases to verify our solution.

# Our test tree:
#        a
#      /   \
#     b     c
#    /\    / \
#   d  e  f   g
#  /  /  / \  |\
# h  i  j   k l m
h, i, j, k, l, m = Node("h"), Node("i"), Node("j"), Node("k"), Node("l"), Node("m")
d, e, f, g = Node("d", left=h), Node("e", left=i), Node("f", left=j, right=k), Node("g", left=l, right=m)
b, c = Node("b", left=d, right=e), Node("c", left=f, right=g)
a = Node("a", left=b, right=c)
tree = a

assert k_dist_nodes(root=tree, target=e, k=4) == {f, g}
assert k_dist_nodes(root=tree, target=e, k=1) == {b, i}
assert k_dist_nodes(root=tree, target=b, k=2) == {h, i, c}
assert k_dist_nodes(root=tree, target=a, k=3) == {h, i, j, k, l, m}

Solution

We make two passes. The first pass traverses the whole tree and records each node's parent; either a DFS or a BFS works here. The second pass starts at target and runs a level-by-level BFS that stops after exactly k hops, so the nodes left in the queue at that point are the answer.

Conceptually, this re-roots the tree at target. Once parent pointers are available, every node has up to three neighbors (its left child, its right child, and its parent), so the tree can be walked like an undirected graph. A BFS from target over those neighbors, with a visited set so it never steps back to a node it has already seen, reaches nodes in increasing order of their distance from target.
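Since the first pass can be a DFS instead of a BFS, here is a minimal sketch of that variant using an explicit stack, plus a helper that exposes the re-rooted view (children plus parent) the second pass walks over. The names build_parents and neighbors are hypothetical and not part of the original solution.

```python
from typing import Iterator, Optional

class Node:
    def __init__(self, val: str, left: "Node" = None, right: "Node" = None):
        self.val = val
        self.left = left
        self.right = right

def build_parents(root: Optional[Node]) -> dict:
    """Hypothetical helper: iterative DFS mapping every non-root node to its parent."""
    parent = {}
    stack = [root] if root is not None else []
    while stack:
        node = stack.pop()  # LIFO order makes this a depth-first walk
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                stack.append(child)
    return parent

def neighbors(node: Node, parent: dict) -> Iterator[Node]:
    """Hypothetical helper: yield node's neighbors when the tree is treated as undirected."""
    for nbr in (node.left, node.right, parent.get(node)):
        if nbr is not None:
            yield nbr

# Small tree: a has children b and c, and b has a right child e.
e = Node("e")
b = Node("b", right=e)
a = Node("a", left=b, right=Node("c"))
parent = build_parents(a)
print([n.val for n in neighbors(b, parent)])  # ['e', 'a']
print([n.val for n in neighbors(a, parent)])  # ['b', 'c']
```

The order in which the DFS discovers nodes does not matter here, because the parent map only needs each child-to-parent edge once.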

from collections import deque

def k_dist_nodes(root: Node, target: Node, k: int) -> set[Node]:
    if root is None:
        return set()
    # Pass 1: BFS over the whole tree, recording each node's parent.
    q = deque([root])
    parent = {}
    while q:
        node = q.popleft()
        if node.left is not None:
            parent[node.left] = node
            q.append(node.left)
        if node.right is not None:
            parent[node.right] = node
            q.append(node.right)
    # Pass 2: BFS outward from target, treating the left child, right child,
    # and parent as neighbors, for exactly k levels.
    visited = {target}
    q.append(target)
    while q and k > 0:
        for _ in range(len(q)):  # process exactly one level
            node = q.popleft()  # must be FIFO; pop() would mix levels
            for nbr in (node.left, node.right, parent.get(node)):
                if nbr is not None and nbr not in visited:
                    visited.add(nbr)
                    q.append(nbr)
        k -= 1
    # The queue now holds exactly the nodes k edges away from target.
    return set(q)
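For testing, it can help to compare against a slower but more obviously correct reference: compute the distance from target to every node with an unbounded BFS, then keep the nodes whose distance equals k. The helpers all_distances and k_dist_reference are hypothetical names used only for this sketch, which rebuilds the same test tree as the Examples section.

```python
from collections import deque

class Node:
    def __init__(self, val: str, left: "Node" = None, right: "Node" = None):
        self.val = val
        self.left = left
        self.right = right

def all_distances(root: Node, target: Node) -> dict:
    """Hypothetical helper: map every node to its edge distance from target."""
    # Record each node's parent with a BFS from the root.
    parent = {}
    q = deque([root])
    while q:
        node = q.popleft()
        for child in (node.left, node.right):
            if child is not None:
                parent[child] = node
                q.append(child)
    # Unbounded BFS from target over children and parent.
    dist = {target: 0}
    q.append(target)
    while q:
        node = q.popleft()
        for nbr in (node.left, node.right, parent.get(node)):
            if nbr is not None and nbr not in dist:
                dist[nbr] = dist[node] + 1
                q.append(nbr)
    return dist

def k_dist_reference(root: Node, target: Node, k: int) -> set:
    """Hypothetical reference: keep the nodes whose distance from target equals k."""
    return {node for node, dist in all_distances(root, target).items() if dist == k}

# The same test tree as in the Examples section.
h, i, j, k, l, m = Node("h"), Node("i"), Node("j"), Node("k"), Node("l"), Node("m")
d, e, f, g = Node("d", left=h), Node("e", left=i), Node("f", left=j, right=k), Node("g", left=l, right=m)
b, c = Node("b", left=d, right=e), Node("c", left=f, right=g)
a = Node("a", left=b, right=c)

print(sorted(n.val for n in k_dist_reference(a, e, 4)))  # ['f', 'g']
print(max(all_distances(a, e).values()))                 # 5
```

Comparing k_dist_nodes against k_dist_reference for every k from 0 up to the maximum distance is a cheap way to catch ordering bugs, such as popping from the wrong end of the deque.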

Complexity

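Let $n$ be the number of nodes in the binary tree.

The first pass visits every node exactly once to build the parent map, and the second pass visits each node at most once, because the visited set stops it from stepping back to nodes it has already seen. Each visit does a constant amount of work, so the runtime complexity is $O(n)$.

The parent map holds an entry for every node except the root, and the visited set can grow to $n$ entries. The queue holds roughly one level (one BFS frontier) at a time, which in a perfect binary tree can be about $n/2$ nodes.

Therefore, the auxiliary space complexity is $O(n)$.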
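To make the queue bound concrete, a sketch that builds a perfect binary tree of height h (via a hypothetical helper, perfect_tree) and records the largest queue size seen during a level-order BFS. Such a tree has $2^{h+1} - 1$ nodes and its widest level has $2^h$, so the queue peaks at about half the nodes.

```python
from collections import deque

class Node:
    def __init__(self, val: str, left: "Node" = None, right: "Node" = None):
        self.val = val
        self.left = left
        self.right = right

def perfect_tree(height: int) -> Node:
    """Hypothetical helper: build a perfect binary tree of the given height."""
    if height == 0:
        return Node("leaf")
    return Node("inner", perfect_tree(height - 1), perfect_tree(height - 1))

def peak_queue_size(root: Node) -> int:
    """Largest number of nodes held in the queue during a BFS from root."""
    q = deque([root])
    peak = len(q)
    while q:
        node = q.popleft()
        for child in (node.left, node.right):
            if child is not None:
                q.append(child)
        peak = max(peak, len(q))
    return peak

for h in range(5):
    n = 2 ** (h + 1) - 1
    print(h, n, peak_queue_size(perfect_tree(h)))
# prints: 0 1 1 / 1 3 2 / 2 7 4 / 3 15 8 / 4 31 16
```

The peak equals $2^h = (n + 1) / 2$, which is why the queue alone already needs $O(n)$ space in the worst case.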