
Problem Definition

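Given a binary tree, return the maximum subtree sum. A subtree is a node together with all of its descendants, and its sum is the total of their values. The entire tree, rooted at the root, counts as one of these subtrees.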


Node Class
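
The solution needs a node with an integer `value` and optional `left` and `right` children, built as `Node(value, left=None, right=None)`. Below is a minimal sketch of such a class, inferred from how the example uses it. The exact signature and type hints are assumptions.

```python
from typing import Optional


class Node:
    """A binary tree node holding an integer value and optional children."""

    def __init__(self, value: int, left: Optional["Node"] = None, right: Optional["Node"] = None):
        self.value = value
        self.left = left    # left child, or None if absent
        self.right = right  # right child, or None if absent
```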


Example

A test case to verify our solution.

```python
#      1
#    /   \
#  -2     3
#  / \   / \
# 4   7 1  -5

four = Node(4)
seven = Node(7)
other_one = Node(1)
neg_five = Node(-5)
neg_two = Node(-2, four, seven)
three = Node(3, other_one, neg_five)
one = Node(1, neg_two, three)

assert max_sum(one) == 9
```
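
You can confirm the expected value of 9 without relying on the solution. The brute-force sketch below recomputes each subtree's sum from scratch. The helpers `subtree_sum`, `all_nodes`, and `max_sum_brute_force` are hypothetical names introduced here. The block also redefines a minimal `Node` class (assumed `value`/`left`/`right` shape) so that it runs on its own. In the worst case it takes $O(n^2)$ time, so treat it as a checker rather than a solution.

```python
class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def subtree_sum(node):
    """Sum of every value in the subtree rooted at node (0 for an absent node)."""
    if node is None:
        return 0
    return node.value + subtree_sum(node.left) + subtree_sum(node.right)


def all_nodes(node):
    """Yield every node of the tree in pre-order."""
    if node is None:
        return
    yield node
    yield from all_nodes(node.left)
    yield from all_nodes(node.right)


def max_sum_brute_force(root):
    # Recomputes each subtree's sum independently, hence O(n^2) in the worst case.
    return max(subtree_sum(n) for n in all_nodes(root))


one = Node(1, Node(-2, Node(4), Node(7)), Node(3, Node(1), Node(-5)))
print([subtree_sum(n) for n in all_nodes(one)])
print(max_sum_brute_force(one))
```

On the example tree, the pre-order subtree sums are 9, 9, 4, 7, -1, 1 and -5. The maximum is therefore 9, reached by both the root and the `-2` subtree.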

Solution

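We implement a recursive depth-first search that processes nodes in post-order, handling children before their parent. For each node, `helper` returns a pair: the sum of the subtree rooted at that node, and the largest subtree sum found anywhere within that subtree. A node's subtree sum is its own value plus its children's subtree sums. Its best value is the largest of its own subtree sum and its children's best values. Every subtree sum is therefore computed exactly once.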

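```python
import math

def helper(node) -> tuple[int, int]:
    """Return (sum of the subtree rooted at node, max subtree sum within it)."""
    # A missing child adds 0 to the sum and no candidate to the max.
    left_sum, right_sum = 0, 0
    left_max, right_max = -math.inf, -math.inf
    if node.left is not None:
        left_sum, left_max = helper(node.left)
    if node.right is not None:
        right_sum, right_max = helper(node.right)
    # Sum of the subtree rooted at this node (children were summed first).
    curr_sum = left_sum + right_sum + node.value
    # The best subtree is this one or the best found inside either child.
    curr_max = max(curr_sum, left_max, right_max)
    return (curr_sum, curr_max)

def max_sum(root) -> int:
    """Maximum subtree sum of a non-empty binary tree."""
    return helper(root)[1]
```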
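
The sketch below shows the pairs that `helper` returns. `helper_traced` is a hypothetical instrumented copy that follows the same recursion. As each call finishes, it appends `(value, curr_sum, curr_max)` to a list. The `Node` class is again an assumed minimal version, included so the block is self-contained.

```python
import math


class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def helper_traced(node, trace):
    """Same recursion as helper, but records (value, curr_sum, curr_max) on return."""
    left_sum, right_sum = 0, 0
    left_max, right_max = -math.inf, -math.inf
    if node.left is not None:
        left_sum, left_max = helper_traced(node.left, trace)
    if node.right is not None:
        right_sum, right_max = helper_traced(node.right, trace)
    curr_sum = left_sum + right_sum + node.value
    curr_max = max(curr_sum, left_max, right_max)
    trace.append((node.value, curr_sum, curr_max))  # appended after both children: post-order
    return curr_sum, curr_max


root = Node(1, Node(-2, Node(4), Node(7)), Node(3, Node(1), Node(-5)))
trace = []
helper_traced(root, trace)
for value, curr_sum, curr_max in trace:
    print(f"node {value:>2}: subtree sum = {curr_sum:>2}, best within subtree = {curr_max}")
```

The records come out in post-order. The call for node `3` returns a subtree sum of -1 but a best value of 1, because its left leaf alone beats the whole subtree. The root then combines its own sum of 9 with the best values 9 (from the `-2` side) and 1 (from the `3` side), which gives the answer 9.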

Complexity

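Let $n$ be the number of nodes in the tree.

The runtime is $\Theta(n)$. Each node is visited exactly once, and each visit does a constant amount of work besides the recursive calls.

The auxiliary space is $O(h)$, where $h$ is the height of the tree, because the call stack holds one frame per node on the current root-to-node path. In the worst case the tree degenerates into a path where every node has at most one child, like a singly linked list. Then $h = n$ and the space is $O(n)$. For a balanced tree it is $O(\log n)$.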
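
CPython caps recursion depth, and the default reported by `sys.getrecursionlimit()` is usually 1000. On a deep, list-like tree, the recursive `helper` can therefore raise `RecursionError`. The iterative alternative below is a minimal sketch that assumes the same `Node` shape (`value`, `left`, `right`); `max_sum_iterative` is a hypothetical name. It replaces the call stack with an explicit stack of `(node, children_done)` pairs. In the worst case it still uses $O(n)$ auxiliary space, but that memory lives on the heap rather than the call stack.

```python
import math
from typing import Optional


class Node:
    def __init__(self, value: int, left: Optional["Node"] = None, right: Optional["Node"] = None):
        self.value = value
        self.left = left
        self.right = right


def max_sum_iterative(root: Node) -> int:
    """Maximum subtree sum of a non-empty tree, computed without recursion."""
    sums = {}  # node -> sum of its subtree, kept until its parent consumes it
    best = -math.inf
    # Entries are (node, children_done). A node is pushed back with
    # children_done=True *before* its children are pushed, so it is popped
    # again only after both children have been summed (post-order).
    stack = [(root, False)]
    while stack:
        node, children_done = stack.pop()
        if not children_done:
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            total = node.value
            if node.left is not None:
                total += sums.pop(node.left)
            if node.right is not None:
                total += sums.pop(node.right)
            sums[node] = total
            best = max(best, total)
    return best


# Usage: the example tree, then a 5000-node chain that is deeper than
# CPython's default recursion limit.
root = Node(1, Node(-2, Node(4), Node(7)), Node(3, Node(1), Node(-5)))
print(max_sum_iterative(root))  # 9

chain = None
for _ in range(5000):
    chain = Node(1, chain)
print(max_sum_iterative(chain))  # 5000
```

The dictionary uses nodes themselves as keys. This works because a plain class without `__eq__` hashes by identity, so every node in the tree is a distinct key.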