Solving the Maximum Sum Descent problem using Python
By John Lekberg on February 12, 2020.
This week's post is about solving an interview question: the "Maximum Sum Descent" problem. You will learn:
- how to solve this problem using a brute force algorithm.
- how to design a more efficient solution using dynamic programming.
Problem statement
Some positive integers are arranged in a triangle, e.g.
2
5 4
3 4 7
1 6 9 6
Your goal is to design an algorithm that finds the largest sum in a descent from its apex to the base through a sequence of adjacent numbers, one number per level.
(I use the term path as shorthand for "a descent from its apex to the base through a sequence of adjacent numbers, one number per level".)
Here is an example path:
2
5 .
. 4 .
. . 9 .
Its sum is 20 (2+5+4+9).
How I represent the triangle in Python
I represent the triangle in Python as a list of lists:
T = [
[2],
[5, 4],
[3, 4, 7],
[1, 6, 9, 6],
]
T[r][n] is the nth node on row r, counting rows and nodes from zero.
Every node T[r][n] (except for nodes in the bottom row) has two adjacent nodes on the next level:
- T[r+1][n]
- T[r+1][n+1]
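To make the adjacency rule concrete, here is a small sketch. The helper name `children` is hypothetical (it is not part of the original post): it returns the positions adjacent to node (r, n) on the next row, and an empty list for nodes in the bottom row.

```python
# The example triangle from above.
T = [
    [2],
    [5, 4],
    [3, 4, 7],
    [1, 6, 9, 6],
]

def children(T, r, n):
    """Return the (row, index) pairs adjacent to node (r, n) on the next row.

    Hypothetical helper for illustration. Nodes in the bottom row have no children.
    """
    if r == len(T) - 1:
        return []
    return [(r + 1, n), (r + 1, n + 1)]

# Row r of a triangle has r + 1 nodes.
assert all(len(row) == r + 1 for r, row in enumerate(T))

print(children(T, 1, 0))  # [(2, 0), (2, 1)]  -> the 3 and 4 below the 5
print(children(T, 3, 2))  # []                -> bottom row
```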
A brute force solution
A brute force algorithm can solve this problem by checking every path and returning the largest path-sum.
How many paths are there?
For a triangle with N rows, there are $2^{N-1}$ paths. Here's a short proof by induction:
- A triangle with 1 row has 1 path ($2^{1-1} = 2^0 = 1$).
- Suppose a triangle with N rows has $2^{N-1}$ paths. Because each node (except for the bottom row) has two adjacent nodes on the next level, when we add another row to the N-row triangle, each of the $2^{N-1}$ paths can continue to the new bottom row in two ways. Thus an (N+1)-row triangle has $2 \cdot 2^{N-1} = 2^N$ paths.
The specific triangle that I am solving the "Maximum Sum Descent" problem for has 4 rows, so there are 8 ($2^{4-1}$) possible paths. This is a small enough number of paths to allow for a brute force solution.
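The induction argument can also be checked numerically. This sketch uses a hypothetical helper, `count_paths`, that mirrors the proof: a bottom-row node ends exactly one path, and every other node continues through both of its children.

```python
def count_paths(N, r=0, n=0):
    """Count the paths that start at node (r, n) of an N-row triangle.

    Hypothetical helper for illustration. The index n does not change the
    count; it is tracked only so the recursion follows the triangle's shape.
    """
    if r == N - 1:
        return 1  # a bottom-row node ends exactly one path
    # Every other node continues to its two adjacent nodes on the next row.
    return count_paths(N, r + 1, n) + count_paths(N, r + 1, n + 1)

# The count matches 2**(N-1) for small triangles.
for N in range(1, 11):
    assert count_paths(N) == 2 ** (N - 1)

print(count_paths(4))  # 8
```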
How can I generate all paths for an N row triangle?
I represent a path as a sequence of (row, node) indices:
$[(r_1, n_1), (r_2, n_2), (r_3, n_3), \ldots]$
For example, this path
2
5 .
. 4 .
. . 9 .
is represented as
[(0, 0), (1, 0), (2, 1), (3, 2)]
Because the row indices count up from zero, I can also represent this path using enumerate:
list(enumerate([0, 0, 1, 2]))
[(0, 0), (1, 0), (2, 1), (3, 2)]
Now I need to generate $n_1, n_2, \ldots$ in $[(r_1, n_1), (r_2, n_2), (r_3, n_3), \ldots]$.
Each node (except those in the last row) has two adjacent nodes on the next level, so either $n_2 = n_1$ or $n_2 = n_1 + 1$ (and similarly for $n_3$, $n_4$, etc.).
This means that I can write the node indices as
$[n_1, n_1 + b_2, n_1 + b_2 + b_3, \ldots]$
where $b_2, b_3, \ldots$ are each 0 or 1.
The triangle's first row has only one node, so $n_1$ is always 0.
Here's a Python function, paths, that implements this technique to generate all paths for an N-row triangle:
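from itertools import product, accumulate

def paths(N):
    """Generate all paths for an `N`-row triangle."""
    # Each path is determined by N-1 choices: 0 = stay, 1 = step right.
    B = product([0, 1], repeat=N-1)
    for choices in B:
        # Node indices are running sums of the choices, starting at 0.
        path = enumerate([0, *accumulate(choices)])
        yield list(path)

list(paths(4))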
[[(0, 0), (1, 0), (2, 0), (3, 0)],
[(0, 0), (1, 0), (2, 0), (3, 1)],
[(0, 0), (1, 0), (2, 1), (3, 1)],
[(0, 0), (1, 0), (2, 1), (3, 2)],
[(0, 0), (1, 1), (2, 1), (3, 1)],
[(0, 0), (1, 1), (2, 1), (3, 2)],
[(0, 0), (1, 1), (2, 2), (3, 2)],
[(0, 0), (1, 1), (2, 2), (3, 3)]]
itertools.product is used to generate $b_2, b_3, \ldots$, and itertools.accumulate takes a cumulative sum of the sequence, so
$[0, b_2, b_3, b_4, \ldots]$
becomes
$[0, b_2, b_2 + b_3, b_2 + b_3 + b_4, \ldots]$
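As a worked example (not from the original post), here are the choice bits for the example path 2, 5, 4, 9 from the problem statement: stay, step right, step right. Running them through accumulate reproduces the index list used with enumerate earlier.

```python
from itertools import accumulate

# Choices b2, b3, b4 for the example path 2 -> 5 -> 4 -> 9:
# stay (0), step right (1), step right (1).
choices = (0, 1, 1)

# Prepending n1 = 0 and taking running sums gives the node indices.
indices = [0, *accumulate(choices)]
print(indices)  # [0, 0, 1, 2]

# Pairing each index with its row number recovers the path.
print(list(enumerate(indices)))  # [(0, 0), (1, 0), (2, 1), (3, 2)]
```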
The brute force solution will compute the sum along every path. Here's a Python function to compute the sum along a single path:
def path_sum(T, path):
"""Return the sum of values in `T` along `path`."""
return sum(T[r][n] for r, n in path)
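As a quick sanity check, applying path_sum to the example path from the problem statement should reproduce its sum of 20. The block repeats the triangle and the function definition so that it runs on its own.

```python
T = [[2], [5, 4], [3, 4, 7], [1, 6, 9, 6]]

def path_sum(T, path):
    """Return the sum of values in `T` along `path`."""
    return sum(T[r][n] for r, n in path)

# The example path 2 -> 5 -> 4 -> 9 from the problem statement.
example = [(0, 0), (1, 0), (2, 1), (3, 2)]
print(path_sum(T, example))  # 20
```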
And the brute force solution to the Maximum Sum Descent problem is:
def msd_brute_force(T):
    """Solve the Maximum Sum Descent problem for `T`."""
    N = len(T)
    # Check every path and keep the largest path-sum.
    return max(path_sum(T, p) for p in paths(N))

msd_brute_force(T)
22
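The brute force solution returns only the largest sum. If you also want to know which path achieves it, one possible variation (a hypothetical helper, `best_path_brute_force`, not from the original post) passes a key function to max so that it returns the path itself. The definitions of paths and path_sum are repeated so the block is self-contained.

```python
from itertools import product, accumulate

T = [[2], [5, 4], [3, 4, 7], [1, 6, 9, 6]]

def paths(N):
    """Generate all paths for an `N`-row triangle."""
    for choices in product([0, 1], repeat=N-1):
        yield list(enumerate([0, *accumulate(choices)]))

def path_sum(T, path):
    """Return the sum of values in `T` along `path`."""
    return sum(T[r][n] for r, n in path)

def best_path_brute_force(T):
    """Hypothetical variant: return the path with the largest sum, not just the sum."""
    return max(paths(len(T)), key=lambda p: path_sum(T, p))

best = best_path_brute_force(T)
print(best)                         # [(0, 0), (1, 1), (2, 2), (3, 2)]
print([T[r][n] for r, n in best])   # [2, 4, 7, 9]
print(path_sum(T, best))            # 22
```

If several paths tie for the largest sum, max returns the first one that paths generates.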
But how well does the brute force solution scale? For a triangle with N rows:
- There are $2^{N-1}$ paths to check.
- The sum of each path is a sum of N values.
So the time complexity of msd_brute_force is
$O(N \cdot 2^N)$
This means that if solving a 4 row triangle took 1 second, then solving a 20 row triangle would take almost 4 days.
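The "almost 4 days" figure follows from the ratio of the two costs. This sketch assumes the running time is exactly proportional to $N \cdot 2^N$, with no constant overhead; the helper name `brute_force_cost` is hypothetical.

```python
def brute_force_cost(N):
    """Hypothetical cost model: work proportional to N * 2**N for N rows."""
    return N * 2 ** N

# If a 4-row triangle takes 1 second, scale up to a 20-row triangle.
seconds = brute_force_cost(20) / brute_force_cost(4)
days = seconds / (60 * 60 * 24)
print(seconds)         # 327680.0
print(round(days, 2))  # 3.79
```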
How can we solve this problem more efficiently?
A dynamic programming solution
Dynamic programming is a programming technique that solves complicated problems by first solving overlapping sub-problems and combining those solutions to solve the original problem. (See the conclusion of this post for more information on dynamic programming.)
To create a dynamic programming solution to the Maximum Sum Descent problem, I derive a recurrence relation for the maximum sum.
- Let T be an N-row triangle.
- Let S be the maximum sum descending from a given node: S[r][n] is the maximum sum descent starting from node (r, n).
- The goal is to compute the value of S[0][0].
- The values of S for the bottom row of T (row N-1) are just the values of the bottom row of T, nothing more: S[N-1][n] = T[N-1][n]
- For all other nodes (r, n), the value of S[r][n] is the value of T[r][n] plus the greater value of:
  - S[r+1][n], the maximum sum descending from the left child.
  - S[r+1][n+1], the maximum sum descending from the right child.
So the recurrence relation is:
S[N-1][n] = T[N-1][n]
S[r][n] = T[r][n] + max(S[r+1][n], S[r+1][n+1])
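The recurrence can also be evaluated top-down. The sketch below is an alternative to the bottom-up table, not the approach used in this post: it writes S as a recursive function and uses functools.lru_cache to memoize it, so each node's value is computed only once. The name `msd_memoized` is hypothetical.

```python
from functools import lru_cache

def msd_memoized(T):
    """Solve Maximum Sum Descent top-down with memoized recursion.

    Hypothetical alternative: S(r, n) follows the recurrence directly,
    and lru_cache stores each result so no node is computed twice.
    """
    N = len(T)

    @lru_cache(maxsize=None)
    def S(r, n):
        if r == N - 1:  # base case: bottom row
            return T[r][n]
        return T[r][n] + max(S(r + 1, n), S(r + 1, n + 1))

    return S(0, 0)

T = [[2], [5, 4], [3, 4, 7], [1, 6, 9, 6]]
print(msd_memoized(T))  # 22
```

The recursion depth equals the number of rows, so very tall triangles (around a thousand rows, with CPython's default recursion limit) would hit a RecursionError; the bottom-up loop does not have that limitation.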
Here's msd_dynamic_programming, a Python implementation of a dynamic programming solution that uses the above recurrence relation:
from copy import deepcopy

def msd_dynamic_programming(T):
    """Solve the Maximum Sum Descent problem for `T`."""
    N = len(T)
    # S starts as a copy of T, which covers the bottom-row base case.
    S = deepcopy(T)
    # Fill in rows N-2, ..., 1, 0 using the recurrence relation.
    for r in reversed(range(N-1)):
        for n in range(r+1):
            S[r][n] += max(S[r+1][n], S[r+1][n+1])
    return S[0][0]

msd_dynamic_programming(T)
22
I use deepcopy to initialize S as a copy of the triangle T. (A deep copy, rather than a shallow one, means that updating the rows of S leaves T unchanged.) This copy covers the base case of
S[N-1][n] = T[N-1][n]
And I compute the recurrence relation over the rest of S by iterating over S in reverse row order (N-2, ..., 1, 0), because the values of S in row r depend on the values of S in row r+1.
I use the built-in function reversed to reverse range(N-1) (which produces 0, 1, ..., N-2).
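To see the bottom-up order in action, this sketch repeats the loop from msd_dynamic_programming on the example triangle and prints each row of S once it has been filled in. The apex ends up holding 22, the same answer as the brute force solution.

```python
from copy import deepcopy

T = [[2], [5, 4], [3, 4, 7], [1, 6, 9, 6]]

N = len(T)
S = deepcopy(T)  # base case: the bottom row of S equals the bottom row of T

# reversed(range(N-1)) visits rows N-2, ..., 1, 0.
print(list(reversed(range(N - 1))))  # [2, 1, 0]

for r in reversed(range(N - 1)):
    for n in range(r + 1):
        S[r][n] += max(S[r + 1][n], S[r + 1][n + 1])
    print(r, S[r])
# 2 [9, 13, 16]
# 1 [18, 20]
# 0 [22]
```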
How well does msd_dynamic_programming scale? For a triangle with N rows:
- I iterate over every node in the triangle (some nodes twice), so the time complexity is linear in the number of nodes.
How many nodes are there? There is one node in the first row, two nodes in the second row, three nodes in the third row, etc. So the total number of nodes is:
1 + 2 + 3 + ... + N
These are known as triangular numbers (A000217). Here's a closed-form expression for the number of nodes in a triangle with N rows:
$\frac{1}{2}(N^2 + N)$
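A quick numerical check (not a derivation) compares the running total with the closed form for the first hundred values of N, and counts the nodes of the example triangle. Integer division is exact here because $N^2 + N = N(N+1)$ is always even.

```python
# Compare 1 + 2 + ... + N with the closed form (N**2 + N) / 2.
for N in range(1, 101):
    assert sum(range(1, N + 1)) == (N * N + N) // 2

# The example triangle has 4 rows, so it should have 10 nodes.
T = [[2], [5, 4], [3, 4, 7], [1, 6, 9, 6]]
print(sum(len(row) for row in T))  # 10
print((4 * 4 + 4) // 2)            # 10
```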
Since the time complexity of msd_dynamic_programming is linear in the number of nodes, the time complexity for a triangle with N rows is:
$O(N^2)$
Even though the time complexity of the dynamic programming solution is much better than the time complexity of the brute force solution, there is a cost.
For a triangle with N rows, the space complexity of the brute force solution is
$O(N)$
because one path is processed at a time and each path has N entries. But the space complexity of the dynamic programming solution is
$O(N^2)$
because a copy of the triangle is made and the triangle has $O(N^2)$ nodes.
This is a common phenomenon with dynamic programming:
- The time complexity goes down. E.g. $O(N \cdot 2^N)$ to $O(N^2)$.
- The space complexity goes up. E.g. $O(N)$ to $O(N^2)$.
For this problem, I think that the benefits of a faster algorithm outweigh the costs of needing more storage.
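If storage does become a concern, a common refinement (not used in this post) keeps only one row of S at a time: each row of S depends only on the row directly below it, so earlier rows can be discarded. The sketch below, under the hypothetical name `msd_one_row`, keeps the $O(N^2)$ running time while needing only $O(N)$ auxiliary space.

```python
def msd_one_row(T):
    """Solve Maximum Sum Descent keeping only one row of partial sums.

    Hypothetical variation: `best` holds S for the row below the current
    one, so auxiliary space is O(N) while time stays O(N**2).
    """
    best = list(T[-1])  # S for the bottom row (a copy, so T is not modified)
    for r in reversed(range(len(T) - 1)):
        # Row r has r + 1 nodes; each adds the larger of its two children.
        best = [T[r][n] + max(best[n], best[n + 1]) for n in range(r + 1)]
    return best[0]

T = [[2], [5, 4], [3, 4, 7], [1, 6, 9, 6]]
print(msd_one_row(T))  # 22
```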
In conclusion...
In this post you learned how to use dynamic programming to solve a problem in quadratic time, instead of exponential time for the brute force solution. Dynamic programming can really reduce the time complexity of an algorithm, but usually requires more storage space. To learn more about dynamic programming, check out these documents:
- "A graphical introduction to dynamic programming" by Avik Das
- "dynamic programming" by the National Institute of Standards and Technology
- "Lecture 19: Dynamic Programming I: Fibonacci, Shortest Paths" by Erik Demaine (MIT)
My challenge to you:
In this post, I used the fact that the triangular numbers
$1 + 2 + 3 + \cdots + N$
have a closed-form solution
$\frac{1}{2}(N^2 + N)$
to prove that the time complexity of the dynamic programming solution is $O(N^2)$.
Show how to derive the closed form solution.
If you enjoyed this post, let me know. Share this with your friends and stay tuned for next week's post. See you then!
(If you spot any errors or typos on this post, contact me via my contact page.)