Solving the Nearest Neighbor Problem using Python
By John Lekberg on April 17, 2020.
This week's post is about solving the "Nearest Neighbor Problem". You will learn:
 How to use a brute force algorithm to solve the problem.
 How to use a spatial index to create a faster solution.
Problem Statement
This problem deals with points in real coordinate space.
You are given a collection of "reference points", e.g.
[ (1, 2), (3, 2), (4, 1), (3, 5) ]
You are also given a collection of "query points", e.g.
[ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]
Your goal is to find the nearest reference point for every query point. E.g.
For query point (3, 4), the closest reference point is (3, 5).
How I represent the data in the problem
I represent a point in real coordinate space as a tuple object. E.g.
(3, 4)
To compare distances (to find the nearest point) I use squared Euclidean distance (SED):
def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j)**2 for i, j in zip(X, Y))

SED( (3, 4), (4, 9) )
26
SED ranks distances the same as Euclidean distance, so it is acceptable to use SED to find a "nearest neighbor".
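The square root is monotonically increasing, so sorting by SED and sorting by Euclidean distance always give the same order. A small check, assuming Python 3.8+ for the standard library's `math.dist` (used here only as the "true" Euclidean distance for comparison):

```python
import math

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

query = (3, 4)
reference_points = [(1, 2), (3, 2), (4, 1), (3, 5)]

# Rank the reference points by SED and by true Euclidean distance.
by_sed = sorted(reference_points, key=lambda p: SED(p, query))
by_euclid = sorted(reference_points, key=lambda p: math.dist(p, query))

print(by_sed == by_euclid)  # True
print(by_sed[0])            # (3, 5)
```

Skipping the square root saves a little work on every comparison without changing which point comes out on top.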
I represent the solution as a dictionary object that maps query points to the nearest reference point. E.g.
{ (5, 1): (4, 1) }
Creating a brute force solution
A brute force solution to the "Nearest Neighbor Problem" will, for each query point, measure the distance (using SED) to every reference point and select the closest reference point:
def nearest_neighbor_bf(*, query_points, reference_points):
    """Use a brute force algorithm to solve the
    "Nearest Neighbor Problem".
    """
    return {
        query_p: min(
            reference_points,
            key=lambda X: SED(X, query_p),
        )
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]

nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
(5, 1): (4, 1),
(7, 3): (4, 1),
(8, 9): (3, 5),
(10, 1): (4, 1),
(3, 3): (3, 2)}
What is the time complexity of this solution? For N query points and M reference points:
 Finding the closest reference point for a given query point takes O(M) steps.
 The algorithm has to find the closest reference point for O(N) query points.
As a result, the overall time complexity of the brute force algorithm is
O(N M).
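One way to see the O(N M) cost directly is to count distance computations. This sketch adds a global counter to SED (an instrumentation I added for illustration, not part of the original function) and runs the brute force solution on the example data:

```python
calls = 0  # number of times SED has been evaluated

def SED(X, Y):
    """Squared Euclidean distance, instrumented to count how often it runs."""
    global calls
    calls += 1
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def nearest_neighbor_bf(*, query_points, reference_points):
    """Brute force: compare every query point with every reference point."""
    return {
        query_p: min(reference_points, key=lambda X: SED(X, query_p))
        for query_p in query_points
    }

reference_points = [(1, 2), (3, 2), (4, 1), (3, 5)]
query_points = [(3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)]
nearest_neighbor_bf(query_points=query_points, reference_points=reference_points)

print(calls)  # 24, i.e. 6 query points * 4 reference points
```

`min` evaluates its `key` once per reference point, so the count is exactly N * M.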
How could I solve this problem faster?
 I always iterate over O(N) query points, so I cannot reduce the N factor.
 However, maybe I can find the nearest reference point in less than O(M) time (faster than checking every reference point).
Using a spatial index to create a faster solution
A spatial index is a data structure used to optimize spatial queries. E.g.
 What is the nearest reference point for a query point?
 What reference points are within a 1 meter radius of a query point?
A k-dimensional tree (kd tree) is a spatial index that uses a binary tree to divide up real coordinate space. Why should I use a kd tree to solve the "Nearest Neighbor Problem"?
 For M reference points, searching for the nearest neighbor of a query point takes, on average, O(log M) time. This is faster than the O(M) time of the brute force algorithm.
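The idea behind a spatial index is easiest to see in one dimension, where a sorted list already acts as one: binary search with the standard library's `bisect` module finds the nearest value in O(log M) steps instead of O(M). This is only a 1-D analogy to motivate the kd tree, not a kd tree itself; `nearest_1d` is a hypothetical helper name.

```python
import bisect

def nearest_1d(sorted_refs, q):
    """Return the value in sorted_refs closest to q, using binary search.

    sorted_refs must be a non-empty list sorted in ascending order.
    """
    i = bisect.bisect_left(sorted_refs, q)  # first index whose value is >= q
    # The nearest value is either just before or at the insertion point.
    candidates = sorted_refs[max(i - 1, 0):i + 1]
    return min(candidates, key=lambda r: abs(r - q))

refs = sorted([1, 3, 4, 3, 9])  # [1, 3, 3, 4, 9]
print(nearest_1d(refs, 7))   # 9
print(nearest_1d(refs, 0))   # 1
print(nearest_1d(refs, 10))  # 9
```

A kd tree generalizes this to k dimensions by alternating the coordinate it splits on at each level of the tree.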
For more information about kd trees, read these documents:
 "Using kd trees to efficiently calculate nearest neighbors in 3D vector space" by Jeremy Day
 "kd tree" on Wikipedia
 "kd tree", Dictionary of Algorithms and Data Structures, by Paul E. Black
 "kdTrees, CMSC 420" by Carl Kingsford
I construct a balanced kd tree using an algorithm outlined on the kd tree Wikipedia page. This is a recursive algorithm that splits the set of reference points in half at the median and recursively builds the left and right subtrees.
Here's a Python function, kdtree, that implements the construction algorithm:
import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])
BT.__doc__ = """
A Binary Tree (BT) with a node value, and left- and
right-subtrees.
"""

def kdtree(points):
    """Construct a kd tree from an iterable of points.

    This algorithm is taken from Wikipedia. For more details,

    > https://en.wikipedia.org/wiki/Kd_tree#Construction
    """
    # Copy the iterable into a list first, so that points[0] also
    # works for generators and the caller's list is never mutated.
    points = list(points)
    k = len(points[0])

    def build(*, points, depth):
        """Build a kd tree from a set of points at a given depth.
        """
        if len(points) == 0:
            return None

        # Split on one axis per level, cycling through all k axes.
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2

        return BT(
            value = points[middle],
            left = build(
                points=points[:middle],
                depth=depth+1,
            ),
            right = build(
                points=points[middle+1:],
                depth=depth+1,
            ),
        )

    return build(points=points, depth=0)

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)
BT(value=(3, 5),
left=BT(value=(3, 2),
left=BT(value=(1, 2), left=None, right=None),
right=None),
right=BT(value=(4, 1), left=None, right=None))
What is the time complexity of constructing a kd tree from M points?

Sorting M points with Python's Timsort takes O(M log M) time.

The construction then recursively builds the left and right subtrees, which involves sorting two lists of about half the size of the original:
$2 \cdot O\left(\tfrac{M}{2} \log \tfrac{M}{2}\right) \le O(M \log M)$
The same bound holds at every depth (the sublists at any one level together hold fewer than M points), so each level of the tree takes O(M log M) time to build.

Because the list of points is halved at each level, there are O(log M) levels in the kd tree.
As a result, the overall time complexity of constructing the kd tree is
$O\left(M (\log M)^2\right)$
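The O(log M) level count can be checked empirically. This sketch uses a condensed copy of kdtree (positional arguments instead of the post's keyword-only ones) and a hypothetical `height` helper. Splitting at the median leaves at most M / 2^d points at depth d, so the tree should have floor(log2 M) + 1 levels, which is what Python's `int.bit_length()` computes:

```python
import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])


def kdtree(points):
    """Median-split kd tree construction, condensed from the post."""
    points = list(points)
    k = len(points[0])

    def build(points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(points[middle],
                  build(points[:middle], depth + 1),
                  build(points[middle + 1:], depth + 1))

    return build(points, 0)


def height(tree):
    """Number of levels in a tree; an empty tree has height 0."""
    if tree is None:
        return 0
    return 1 + max(height(tree.left), height(tree.right))


for m in [1, 10, 100, 1000]:
    points = [(i, (7 * i) % m) for i in range(m)]
    # Median splits give floor(log2 m) + 1 levels, which equals m.bit_length().
    print(m, height(kdtree(points)), m.bit_length())
```

Each printed line shows the same number twice, e.g. a 1000-point tree has 10 levels.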
For the nearest neighbor search, I use an algorithm outlined on the kd tree Wikipedia page. This algorithm is a variation on searching a binary search tree.
Here is a Python function, find_nearest_neighbor, that implements this search algorithm:
NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])
NNRecord.__doc__ = """
Used to keep track of the current best guess during a
nearest neighbor search.
"""

def find_nearest_neighbor(*, tree, point):
    """Find the nearest neighbor in a kd tree for a given
    point.
    """
    k = len(point)
    best = None

    def search(*, tree, depth):
        """Recursively search through the kd tree to find the
        nearest neighbor.
        """
        nonlocal best

        if tree is None:
            return

        distance = SED(tree.value, point)
        if best is None or distance < best.distance:
            best = NNRecord(point=tree.value, distance=distance)

        # Search the side of the splitting plane that contains
        # the query point first.
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left

        search(tree=close, depth=depth+1)
        # Only search the far side if the splitting plane is closer
        # than the best point found so far.
        if diff**2 < best.distance:
            search(tree=away, depth=depth+1)

    search(tree=tree, depth=0)
    return best.point

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(10, 1))
(4, 1)
The average time complexity of a nearest neighbor search in a balanced kd tree of M points is:
O(log M)
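To get a feel for the O(log M) average, this sketch counts how many tree nodes the search actually visits. It uses condensed copies of kdtree and the search (with a float("inf") starting distance instead of the post's None check); the name `nearest_with_count` and the sizes (10,000 uniform random points, 200 queries, fixed seed) are my own assumptions for illustration:

```python
import collections
import operator
import random

BT = collections.namedtuple("BT", ["value", "left", "right"])


def SED(X, Y):
    """Squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))


def kdtree(points):
    """Median-split kd tree construction, condensed from the post."""
    points = list(points)
    k = len(points[0])

    def build(points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(points[middle],
                  build(points[:middle], depth + 1),
                  build(points[middle + 1:], depth + 1))

    return build(points, 0)


def nearest_with_count(tree, point):
    """Nearest neighbor search that also reports how many nodes it visited."""
    k = len(point)
    best_point, best_dist = None, float("inf")
    visited = 0

    def search(node, depth):
        nonlocal best_point, best_dist, visited
        if node is None:
            return
        visited += 1
        d = SED(node.value, point)
        if d < best_dist:
            best_point, best_dist = node.value, d
        diff = point[depth % k] - node.value[depth % k]
        close, away = (node.left, node.right) if diff <= 0 else (node.right, node.left)
        search(close, depth + 1)
        # Cross the splitting plane only if it is nearer than the best so far.
        if diff ** 2 < best_dist:
            search(away, depth + 1)

    search(tree, 0)
    return best_point, visited


rng = random.Random(0)
M = 10_000
reference_points = [(rng.random(), rng.random()) for _ in range(M)]
tree = kdtree(reference_points)
query_points = [(rng.random(), rng.random()) for _ in range(200)]
counts = [nearest_with_count(tree, q)[1] for q in query_points]
print(f"average nodes visited: {sum(counts) / len(counts):.1f} out of {M}")
```

The pruning test `diff ** 2 < best_dist` is what keeps the count small: whole subtrees on the far side of a splitting plane are skipped when they cannot contain a closer point.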
I combine kdtree and find_nearest_neighbor to create a new solution to the "Nearest Neighbor Problem":
def nearest_neighbor_kdtree(*, query_points, reference_points):
    """Use a kd tree to solve the "Nearest Neighbor Problem"."""
    tree = kdtree(reference_points)
    return {
        query_p: find_nearest_neighbor(tree=tree, point=query_p)
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]

nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
(5, 1): (4, 1),
(7, 3): (4, 1),
(8, 9): (3, 5),
(10, 1): (4, 1),
(3, 3): (3, 2)}
nn_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)

nn_kdtree == nn_bf
True
What is the time complexity of this algorithm? For N query points and M reference points:
 Building the kd tree takes $O\left(M (\log M)^2\right)$ time.
 Each nearest neighbor search takes O(log M) time.
 O(N) nearest neighbor searches are conducted.
As a result, the overall time complexity of nearest_neighbor_kdtree is
$O\left(M (\log M)^2 + N \log M\right)$
Asymptotically, this is faster than the O(N M) brute force algorithm when N and M are large, but it is hard for me to think of simple examples that show this. Instead, I will give empirical measurements.
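Before measuring, a rough step count can help. This sketch plugs N and M into the two complexity formulas while ignoring constant factors (an assumption, so the numbers are proportional to running time at best); `step_estimates` is a hypothetical helper. For N = M = 4000, the brute force formula gives 16 million steps and the kd tree formula gives about 0.62 million, a ratio of roughly 26:

```python
import math


def step_estimates(N, M):
    """Return rough step counts (N*M, M*(log2 M)^2 + N*log2 M).

    Constant factors are ignored, so these are only proportional
    to running time, not equal to it.
    """
    log_m = math.log2(M)
    brute_force = N * M
    kd_tree = M * log_m ** 2 + N * log_m
    return brute_force, kd_tree


for size in [100, 4_000, 100_000]:
    bf, kd = step_estimates(size, size)
    print(f"N = M = {size}: brute force ~{bf:.3g}, kd tree ~{kd:.3g}, ratio ~{bf / kd:.0f}")
```

The gap widens as the inputs grow, but because constants are dropped, these ratios should not be read as predicted speedups.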
Confirming that the kd tree algorithm and the brute force algorithm give the same results
When working with complex algorithms like this, it's easy for me to make a mistake. To reduce my anxiety that I made a mistake with the kd tree algorithm, I will generate test data and make sure that the results match the results of the brute force algorithm:
import random

def random_point():
    """Return a random point in the unit square."""
    return (random.random(), random.random())

reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]

solution_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points
)
solution_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points
)

solution_bf == solution_kdtree
True
Before this test, I was already sure that the brute force algorithm is correct. Running this test makes me confident that my kd tree algorithm is also correct.
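One caveat about comparing the dictionaries directly: it works for random floats, where ties are vanishingly unlikely, but when two reference points are equally close to a query point the two algorithms can legitimately return different points. This sketch (condensed copies of the post's functions, plus a hypothetical `same_distances` helper) compares distances instead of points, and uses small random integer coordinates (my choice) so that ties are common:

```python
import collections
import operator
import random

BT = collections.namedtuple("BT", ["value", "left", "right"])


def SED(X, Y):
    """Squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))


def nearest_neighbor_bf(*, query_points, reference_points):
    """Brute force: the first closest reference point for each query point."""
    return {q: min(reference_points, key=lambda X: SED(X, q)) for q in query_points}


def kdtree(points):
    """Median-split kd tree construction, condensed from the post."""
    points = list(points)
    k = len(points[0])

    def build(points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(points[middle],
                  build(points[:middle], depth + 1),
                  build(points[middle + 1:], depth + 1))

    return build(points, 0)


def find_nearest_neighbor(*, tree, point):
    """kd tree nearest neighbor search, condensed from the post."""
    k = len(point)
    best = {"point": None, "distance": float("inf")}

    def search(node, depth):
        if node is None:
            return
        d = SED(node.value, point)
        if d < best["distance"]:
            best.update(point=node.value, distance=d)
        diff = point[depth % k] - node.value[depth % k]
        close, away = (node.left, node.right) if diff <= 0 else (node.right, node.left)
        search(close, depth + 1)
        if diff ** 2 < best["distance"]:
            search(away, depth + 1)

    search(tree, 0)
    return best["point"]


def nearest_neighbor_kdtree(*, query_points, reference_points):
    tree = kdtree(reference_points)
    return {q: find_nearest_neighbor(tree=tree, point=q) for q in query_points}


def same_distances(solution_a, solution_b):
    """True if both solutions give every query point an equally near neighbor."""
    return solution_a.keys() == solution_b.keys() and all(
        SED(q, solution_a[q]) == SED(q, solution_b[q]) for q in solution_a
    )


# A deliberate tie: (0, 0) and (2, 0) are both at distance 1 from (1, 0).
tie_refs = [(0, 0), (2, 0)]
bf = nearest_neighbor_bf(query_points=[(1, 0)], reference_points=tie_refs)
kd = nearest_neighbor_kdtree(query_points=[(1, 0)], reference_points=tie_refs)
print(bf, kd)                            # {(1, 0): (0, 0)} {(1, 0): (2, 0)}
print(bf == kd, same_distances(bf, kd))  # False True

# Random small-integer points, where ties are common.
rng = random.Random(1)
for _ in range(100):
    refs = [(rng.randint(0, 10), rng.randint(0, 10)) for _ in range(50)]
    queries = [(rng.randint(0, 10), rng.randint(0, 10)) for _ in range(50)]
    assert same_distances(
        nearest_neighbor_bf(query_points=queries, reference_points=refs),
        nearest_neighbor_kdtree(query_points=queries, reference_points=refs),
    )
```

Both answers in the tie example are correct nearest neighbors; only the distance-based check recognizes that.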
Comparing the measured speed of the kd tree algorithm and the brute force algorithm
I generate some test data and use the cProfile module to measure the running time of each algorithm:
import cProfile

reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]

cProfile.run("""
nearest_neighbor_bf(
    reference_points=reference_points,
    query_points=query_points,
)
""")
96004005 function calls in 26.252 seconds
...
cProfile.run("""
nearest_neighbor_kdtree(
    reference_points=reference_points,
    query_points=query_points,
)
""")
516215 function calls (422736 primitive calls) in 0.231 seconds
...
The brute force algorithm takes 26.252 seconds and the kd tree algorithm takes 0.231 seconds (over 100 times faster).
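cProfile records every function call, and the brute force run makes about 96 million of them, so the profiler's own overhead probably inflates that timing more than the kd tree's. A plain wall-clock measurement with `time.perf_counter` avoids that overhead. This is a sketch under my own assumptions: `best_time` is a hypothetical helper, and the data is smaller (500 points) to keep the example quick; pass `nearest_neighbor_kdtree` to it the same way.

```python
import random
import time


def SED(X, Y):
    """Squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))


def nearest_neighbor_bf(*, query_points, reference_points):
    """Brute force solution, condensed from the post."""
    return {q: min(reference_points, key=lambda X: SED(X, q)) for q in query_points}


def best_time(fn, *, repeat=3, **kwargs):
    """Return the fastest of `repeat` wall-clock runs of fn(**kwargs), in seconds."""
    timings = []
    for _ in range(repeat):
        start = time.perf_counter()
        fn(**kwargs)
        timings.append(time.perf_counter() - start)
    return min(timings)


rng = random.Random(0)
reference_points = [(rng.random(), rng.random()) for _ in range(500)]
query_points = [(rng.random(), rng.random()) for _ in range(500)]
seconds = best_time(nearest_neighbor_bf,
                    query_points=query_points,
                    reference_points=reference_points)
print(f"brute force: {seconds:.3f} s")
```

Taking the minimum of several runs is a common way to reduce noise from other processes on the machine.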
In conclusion...
In this week's post, you learned how to solve the "Nearest Neighbor Problem" efficiently using a kd tree (a kind of spatial index). Kd trees allow you to efficiently query large sets of spatial data to find close points. Kd trees and other spatial indices are used in databases to optimize queries.
My challenge to you:
I wrote a previous post, "Building a command line tool to simulate the spread of an infection", that showed you how to build a simulation of an infection spreading in a population.
Determining whether a susceptible person gets infected requires finding the nearest infected person. The algorithm that post uses to find a nearby infected person is a brute force algorithm.
My challenge to you is to replace the brute force algorithm in the infection simulator with the kd tree algorithm that we built in this post.
If you enjoyed this week's post, share it with your friends and stay tuned for next week's post. See you then!