# Solving the Nearest Neighbor Problem using Python

By John Lekberg on April 17, 2020.

This week's post is about solving the "Nearest Neighbor Problem". You will learn:

• What the nearest neighbor problem is.
• How to solve it with a brute force algorithm.
• How to solve it faster with a k-d tree (a kind of spatial index).

# Problem Statement

This problem deals with points in real coordinate space.

You are given a collection of "reference points", e.g.

``````
[ (1, 2), (3, 2), (4, 1), (3, 5) ]
``````

And you are given a collection of "query points", e.g.

``````
[ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]
``````

Your goal is to find the nearest reference point for every query point. E.g.

For query point (3, 4), the closest reference point is (3, 5).

# How I represent the data in the problem

I represent a point in real coordinate space as a tuple object. E.g.

``````
(3, 4)
``````

To compare distances (to find the nearest point) I use squared Euclidean distance (SED):

``````
def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

SED( (3, 4), (4, 9) )
``````
``````
26
``````

SED ranks distances the same as Euclidean distance (the square root function is monotonically increasing), so it is acceptable to use SED to find a "nearest neighbor".
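A quick way to convince yourself of this: sorting the reference points by SED and by true Euclidean distance yields the same order. This is a small sketch using the post's example points; `math.dist` computes the Euclidean distance.

```python
import math

def SED(X, Y):
    """Squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

reference_points = [(1, 2), (3, 2), (4, 1), (3, 5)]
query = (3, 4)

# Both sorts produce the same ranking, because sqrt is monotonic.
by_sed = sorted(reference_points, key=lambda p: SED(p, query))
by_euclid = sorted(reference_points, key=lambda p: math.dist(p, query))
assert by_sed == by_euclid
assert by_sed[0] == (3, 5)  # the nearest reference point to (3, 4)
```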

I represent the solution as a dictionary object that maps query points to the nearest reference point. E.g.

``````
{ (5, 1): (4, 1) }
``````

# Creating a brute force solution

A brute force solution to the "Nearest Neighbor Problem" will, for each query point, measure the distance (using SED) to every reference point and select the closest reference point:

``````
def nearest_neighbor_bf(*, query_points, reference_points):
    """Use a brute force algorithm to solve the
    "Nearest Neighbor Problem".
    """
    return {
        query_p: min(
            reference_points,
            key=lambda X: SED(X, query_p),
        )
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
``````
``````
{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}
``````

What is the time complexity of this solution? For N query points and M reference points:

• Finding the closest reference point for a given query point takes O(M) steps.
• The algorithm has to find the closest reference point for O(N) query points.

As a result, the overall time complexity of the brute force algorithm is

O(N M).

How could I solve this problem faster?

• I always iterate over O(N) query points, so I cannot reduce the N factor.
• However, maybe I can find the nearest reference point in fewer than O(M) steps (faster than checking every reference point).

# Using a spatial index to create a faster solution

A spatial index is a data structure used to optimize spatial queries. E.g.

• What is the nearest reference point for a query point?
• What reference points are within a 1 meter radius of a query point?

A k-dimensional tree (k-d tree) is a spatial index that uses a binary tree to divide up real coordinate space. Why should I use a k-d tree to solve the "Nearest Neighbor Problem"?

• For M reference points, searching for the nearest neighbor of a query point takes, on average, O(log M) time. This is faster than the O(M) time of the brute force algorithm.

I construct a balanced k-d tree using an algorithm outlined on the k-d tree Wikipedia page. This is a recursive algorithm that divides the set of reference points in half at the median and recursively builds left- and right-subtrees.

Here's a Python function, `kdtree`, that implements the construction algorithm:

``````
import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])
BT.__doc__ = """
A Binary Tree (BT) with a node value, and left- and
right-subtrees.
"""

def kdtree(points):
    """Construct a k-d tree from an iterable of points.

    This algorithm is taken from Wikipedia. For more details,

    > https://en.wikipedia.org/wiki/K-d_tree#Construction

    """
    points = list(points)
    k = len(points[0])  # Dimensionality of the points.

    def build(*, points, depth):
        """Build a k-d tree from a set of points at a given
        depth.
        """
        if len(points) == 0:
            return None

        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2

        return BT(
            value = points[middle],
            left = build(
                points=points[:middle],
                depth=depth+1,
            ),
            right = build(
                points=points[middle+1:],
                depth=depth+1,
            ),
        )

    return build(points=points, depth=0)

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)
``````
``````
BT(value=(3, 5),
   left=BT(value=(3, 2),
           left=BT(value=(1, 2), left=None, right=None),
           right=None),
   right=BT(value=(4, 1), left=None, right=None))
``````
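A quick way to gain confidence in the construction is to check the k-d tree ordering invariant: at depth d, every point in a node's left subtree is ≤ the node's value on axis d % k, and every point in its right subtree is ≥ it. This sketch re-implements the builder minimally (so the snippet runs on its own); `values` and `check_invariant` are helper names I introduce here, not part of the post's code.

```python
import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])

def kdtree(points, k=2):
    """Minimal k-d tree builder (same median-split scheme as above)."""
    def build(points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(points[middle],
                  build(points[:middle], depth + 1),
                  build(points[middle + 1:], depth + 1))
    return build(list(points), 0)

def values(tree):
    """Yield every point stored in the (sub)tree."""
    if tree is not None:
        yield tree.value
        yield from values(tree.left)
        yield from values(tree.right)

def check_invariant(tree, depth=0, k=2):
    """Verify the k-d tree ordering invariant at every node."""
    if tree is None:
        return True
    axis = depth % k
    return (
        all(v[axis] <= tree.value[axis] for v in values(tree.left))
        and all(v[axis] >= tree.value[axis] for v in values(tree.right))
        and check_invariant(tree.left, depth + 1, k)
        and check_invariant(tree.right, depth + 1, k)
    )

tree = kdtree([(1, 2), (3, 2), (4, 1), (3, 5)])
assert check_invariant(tree)
assert tree.value == (3, 5)  # the root shown in the output above
```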

What is the time complexity of constructing a k-d tree from M points?

• Python's timsort runs in O(M log M) time.

• The construction recursively builds the left- and right-subtrees, which involves sorting two lists of half the size of the original.

2 O(½M log ½M) ≤ O(M log M)

As a result, each level of the tree takes O(M log M) time to build.

• Because the list of points is halved at each level, there are O(log M) levels in the k-d tree.

As a result, the overall time complexity of constructing the k-d tree is

O(M [log M]²)

For the nearest neighbor search, I use an algorithm outlined on the k-d tree Wikipedia page. This algorithm is a variation on searching a binary search tree.

Here is a Python function, `find_nearest_neighbor`, that implements this search algorithm:

``````
NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])
NNRecord.__doc__ = """
Used to keep track of the current best guess during a nearest
neighbor search.
"""

def find_nearest_neighbor(*, tree, point):
    """Find the nearest neighbor in a k-d tree for a given
    point.
    """
    k = len(point)

    best = None
    def search(*, tree, depth):
        """Recursively search through the k-d tree to find the
        nearest neighbor.
        """
        nonlocal best

        if tree is None:
            return

        distance = SED(tree.value, point)
        if best is None or distance < best.distance:
            best = NNRecord(point=tree.value, distance=distance)

        axis = depth % k
        diff = point[axis] - tree.value[axis]
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left

        search(tree=close, depth=depth+1)
        if diff**2 < best.distance:
            search(tree=away, depth=depth+1)

    search(tree=tree, depth=0)
    return best.point

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(10, 1))
``````
``````
(4, 1)
``````

The average time complexity of a nearest neighbor search in a balanced k-d tree of M points is:

O(log M)
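To see where that average comes from, it helps to count how many tree nodes a search actually touches: the pruning test (`diff**2 < best.distance`) skips whole subtrees whose bounding region cannot contain a closer point. The sketch below reproduces the builder and search minimally with a visit counter; `nearest_with_count` is a name I introduce here, and the seeded random data is only for illustration.

```python
import collections
import operator
import random

BT = collections.namedtuple("BT", ["value", "left", "right"])

def SED(X, Y):
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def kdtree(points, k=2):
    """Minimal k-d tree builder (same scheme as above)."""
    def build(points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(points[middle],
                  build(points[:middle], depth + 1),
                  build(points[middle + 1:], depth + 1))
    return build(list(points), 0)

def nearest_with_count(tree, point, k=2):
    """Nearest neighbor search that also counts visited nodes."""
    best = None     # (point, squared distance) of the best guess so far
    visited = 0

    def search(tree, depth):
        nonlocal best, visited
        if tree is None:
            return
        visited += 1
        distance = SED(tree.value, point)
        if best is None or distance < best[1]:
            best = (tree.value, distance)
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        close, away = (tree.left, tree.right) if diff <= 0 else (tree.right, tree.left)
        search(close, depth + 1)
        if diff ** 2 < best[1]:   # only cross the split if it could hide a closer point
            search(away, depth + 1)

    search(tree, 0)
    return best[0], visited

random.seed(0)
M = 1000
points = [(random.random(), random.random()) for _ in range(M)]
tree = kdtree(points)
nearest, visited = nearest_with_count(tree, (0.5, 0.5))

# The answer matches a brute force scan, but far fewer than M nodes
# were visited.
assert nearest == min(points, key=lambda p: SED(p, (0.5, 0.5)))
assert visited < M
```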

I combine `kdtree` and `find_nearest_neighbor` to create a new solution to the "Nearest Neighbor Problem":

``````
def nearest_neighbor_kdtree(*, query_points, reference_points):
    """Use a k-d tree to solve the "Nearest Neighbor Problem"."""
    tree = kdtree(reference_points)
    return {
        query_p: find_nearest_neighbor(tree=tree, point=query_p)
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
``````
``````
{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}
``````
``````
nn_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
nn_kdtree == nn_bf
``````
``````
True
``````

What is the time complexity of this algorithm? For N query points and M reference points:

• Building the k-d tree takes O(M [log M]²) time.
• Each nearest neighbor search takes O(log M) time.
• O(N) nearest neighbor searches are conducted.

As a result, the overall time complexity of `nearest_neighbor_kdtree` is

O(M [log M]² + N log M)

This is asymptotically faster than the brute force algorithm's O(N M), but big-O bounds hide constant factors, so instead of reasoning further I will give empirical measurements.
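Plugging in the sizes used in the profiling section below (N = M = 4000) gives a feel for the gap between the two bounds. These are rough step counts under the assumption of base-2 logarithms, not predicted runtimes.

```python
import math

N = M = 4000  # the sizes used in the profiling section

brute_force = N * M                                      # O(N M)
kdtree_cost = M * math.log2(M) ** 2 + N * math.log2(M)   # O(M [log M]^2 + N log M)

# Roughly 16,000,000 "steps" versus roughly 620,000: about a 25x gap
# in the dominant terms, before constant factors.
print(f"{brute_force:,} vs {kdtree_cost:,.0f}")
```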

# Confirming that the k-d tree algorithm and the brute force algorithm give the same results

When working with complex algorithms like this, it's easy for me to make a mistake. To reduce my anxiety that I made a mistake with the k-d tree algorithm, I will generate test data and make sure that the results match the results of the brute force algorithm:

``````
import random

random_point = lambda: (random.random(), random.random())
reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]

solution_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
solution_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)

solution_bf == solution_kdtree
``````
``````
True
``````

Before this test, I was already sure that the brute force algorithm is correct. Running this test makes me confident that my k-d tree algorithm is also correct.

# Comparing the measured speed of the k-d tree algorithm and the brute force algorithm

I generate some test data and use the cProfile module to measure the running time of each algorithm:

``````
import cProfile

reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]

cProfile.run("""
nearest_neighbor_bf(
    reference_points=reference_points,
    query_points=query_points,
)
""")
``````
``````
96004005 function calls in 26.252 seconds

...
``````
``````
cProfile.run("""
nearest_neighbor_kdtree(
    reference_points=reference_points,
    query_points=query_points,
)
""")
``````
``````
516215 function calls (422736 primitive calls) in 0.231 seconds

...
``````

The brute force algorithm takes 26.252 seconds and the k-d tree algorithm takes 0.231 seconds (over 100 times faster).

# In conclusion...

In this week's post, you learned how to solve the "Nearest Neighbor Problem" efficiently using a k-d tree (a kind of spatial index). K-d trees allow you to efficiently query large sets of spatial data to find close points. K-d trees and other spatial indices are used in databases to optimize queries.

My challenge to you:

I wrote a previous post, "Building a command line tool to simulate the spread of an infection", that showed you how to build a simulation of an infection spreading in a population.

Determining if a susceptible person gets infected requires finding the nearest infected person. The algorithm that post uses to find a nearby infected person is a brute force algorithm.

My challenge to you is to replace the brute force algorithm in the infection simulator with the k-d tree algorithm that we built in this post.

If you enjoyed this week's post, share it with your friends and stay tuned for next week's post. See you then!