Solving the Nearest Neighbor Problem using Python
By John Lekberg on April 17, 2020.
This week's post is about solving the "Nearest Neighbor Problem". You will learn:
- How to use a brute force algorithm to solve the problem.
- How to use a spatial index to create a faster solution.
Problem Statement
This problem deals with points in real coordinate space.
You are given a collection of "reference points", e.g.
[ (1, 2), (3, 2), (4, 1), (3, 5) ]
And you are given a collection of "query points", e.g.
[ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]
Your goal is to find the nearest reference point for every query point. E.g.
For query point (3, 4), the closest reference point is (3, 5).
How I represent the data in the problem
I represent a point in real coordinate space as a tuple object. E.g.
(3, 4)
To compare distances (to find the nearest point) I use squared Euclidean distance (SED):
def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i-j)**2 for i, j in zip(X, Y))

SED( (3, 4), (4, 9) )
26
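Because the square root is monotonically increasing, SED ranks distances in the same order as Euclidean distance, so it is acceptable to use SED to find a "nearest neighbor" (and it avoids computing square roots).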
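One way to check this claim: find the nearest point with SED and with the true Euclidean distance, and confirm that the answers agree. The sketch below uses math.dist (the Euclidean distance, available since Python 3.8); the seed, point count, and query point are arbitrary choices.

```python
import math
import random

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

# The example from above: SED is the square of the Euclidean distance.
print(math.isclose(math.dist((3, 4), (4, 9)) ** 2, SED((3, 4), (4, 9))))

random.seed(0)
points = [(random.random(), random.random()) for _ in range(1000)]
query = (0.5, 0.5)

# Nearest point by SED versus by Euclidean distance (which takes a square root).
nearest_by_sed = min(points, key=lambda p: SED(p, query))
nearest_by_euclid = min(points, key=lambda p: math.dist(p, query))
print(nearest_by_sed == nearest_by_euclid)
```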
I represent the solution as a dictionary object that maps query points to the nearest reference point. E.g.
{ (5, 1): (4, 1) }
Creating a brute force solution
A brute force solution to the "Nearest Neighbor Problem" will, for each query point, measure the distance (using SED) to every reference point and select the closest reference point:
def nearest_neighbor_bf(*, query_points, reference_points):
    """Use a brute force algorithm to solve the
    "Nearest Neighbor Problem".
    """
    return {
        query_p: min(
            reference_points,
            key=lambda X: SED(X, query_p),
        )
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]
nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
(5, 1): (4, 1),
(7, 3): (4, 1),
(8, 9): (3, 5),
(10, 1): (4, 1),
(3, 3): (3, 2)}
What is the time complexity of this solution? For N query points and M reference points:
- Finding the closest reference point for a given query point takes O(M) steps.
- The algorithm has to find the closest reference point for O(N) query points.
As a result, the overall time complexity of the brute force algorithm is
O(N M).
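The N M count can be observed directly by counting how often SED is called. This is a sketch; nearest_neighbor_bf_counted is a hypothetical instrumented copy of nearest_neighbor_bf that wraps SED in a counter.

```python
def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def nearest_neighbor_bf_counted(*, query_points, reference_points):
    """Brute force nearest neighbor search that also counts SED calls.

    Hypothetical helper: same logic as nearest_neighbor_bf, plus a counter.
    """
    calls = 0

    def counted_sed(X, Y):
        nonlocal calls
        calls += 1
        return SED(X, Y)

    solution = {
        query_p: min(reference_points, key=lambda X: counted_sed(X, query_p))
        for query_p in query_points
    }
    return solution, calls

reference_points = [(1, 2), (3, 2), (4, 1), (3, 5)]
query_points = [(3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)]
solution, calls = nearest_neighbor_bf_counted(
    query_points=query_points, reference_points=reference_points,
)
print(calls)  # 6 query points x 4 reference points = 24
```

min calls its key function exactly once per reference point, so each query costs M distance computations and the total is N M.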
How could I solve this problem faster?
- I always iterate over O(N) query points, so I cannot reduce the N factor.
- However, maybe I can find the nearest reference point in fewer than O(M) steps (faster than checking every reference point).
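In one dimension, this is easy to see: if the reference points are sorted once, a binary search finds the nearest one in O(log M) steps instead of O(M). A minimal sketch of that special case using the standard library's bisect module (nearest_1d is a hypothetical helper, not part of this post's solution):

```python
import bisect

def nearest_1d(sorted_refs, q):
    """Return the value in sorted_refs closest to q (1-D special case).

    Assumes sorted_refs is a non-empty list of numbers in ascending order.
    """
    i = bisect.bisect_left(sorted_refs, q)  # O(log M) binary search
    # The nearest value is either just before or at the insertion point.
    candidates = sorted_refs[max(i - 1, 0):i + 1]
    return min(candidates, key=lambda r: (r - q) ** 2)

refs = sorted([1, 3, 4, 8])
print(nearest_1d(refs, 5))   # 4
print(nearest_1d(refs, 10))  # 8
print(nearest_1d(refs, 0))   # 1
```

A k-d tree carries a similar idea (discard part of the space at each step) over to several dimensions, where a single sort order is not enough.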
Using a spatial index to create a faster solution
A spatial index is a data structure used to optimize spatial queries. E.g.
- What is the nearest reference point for a query point?
- What reference points are within a 1 meter radius of a query point?
A k-dimensional tree (k-d tree) is a spatial index that uses a binary tree to divide up real coordinate space. Why should I use a k-d tree to solve the "Nearest Neighbor Problem"?
- For M reference points, searching for the nearest neighbor of a query point takes, on average, O(log M) time. This is faster than the O(M) time of the brute force algorithm.
For more information about k-d trees, read these documents:
- "Using k-d trees to efficiently calculate nearest neighbors in 3D vector space" by Jeremy Day
- "k-d tree" on Wikipedia
- "k-d tree", Dictionary of Algorithms and Data Structures, by Paul E. Black
- "kd-Trees, CMSC 420" by Carl Kingsford
I construct a balanced k-d tree using an algorithm outlined on the k-d tree Wikipedia page. This is a recursive algorithm that divides the set of reference points in half at the median and recursively builds left- and right-subtrees.
Here's a Python function, kdtree, that implements the construction algorithm:
import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])
BT.__doc__ = """
A Binary Tree (BT) with a node value, and left- and
right-subtrees.
"""

def kdtree(points):
    """Construct a k-d tree from an iterable of points.

    This algorithm is taken from Wikipedia. For more details,

    > https://en.wikipedia.org/wiki/K-d_tree#Construction
    """
    # Copy the input into a list so it can be indexed and sorted
    # without mutating the caller's data (works for any iterable).
    points = list(points)
    k = len(points[0])

    def build(*, points, depth):
        """Build a k-d tree from a set of points at a given
        depth.
        """
        if len(points) == 0:
            return None
        # Split on axis (depth mod k), at the median point.
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(
            value = points[middle],
            left = build(
                points=points[:middle],
                depth=depth+1,
            ),
            right = build(
                points=points[middle+1:],
                depth=depth+1,
            ),
        )

    return build(points=points, depth=0)

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)
BT(value=(3, 5),
left=BT(value=(3, 2),
left=BT(value=(1, 2), left=None, right=None),
right=None),
right=BT(value=(4, 1), left=None, right=None))
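To see which axis each node splits on, it can help to print the tree with one indented line per node. A small sketch: describe is a hypothetical helper, and BT and kdtree are repeated so the block runs on its own (assuming 2-D points, so k=2).

```python
import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])

def kdtree(points):
    """Median-split k-d tree construction, repeated from above."""
    points = list(points)
    k = len(points[0])

    def build(*, points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(
            value=points[middle],
            left=build(points=points[:middle], depth=depth + 1),
            right=build(points=points[middle + 1:], depth=depth + 1),
        )

    return build(points=points, depth=0)

def describe(tree, depth=0, k=2):
    """Yield one indented line per node, naming the axis that node splits on."""
    if tree is None:
        return
    yield "  " * depth + f"{tree.value} splits on axis {depth % k}"
    yield from describe(tree.left, depth + 1, k)
    yield from describe(tree.right, depth + 1, k)

lines = list(describe(kdtree([(1, 2), (3, 2), (4, 1), (3, 5)])))
for line in lines:
    print(line)
```

Each level alternates between axis 0 (x) and axis 1 (y): the root (3, 5) splits on x, and its children (3, 2) and (4, 1) split on y.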
What is the time complexity of constructing a k-d tree from M points?
- Python's timsort runs in O(M log M) time.
- The construction recursively builds the left- and right-subtrees, which involves sorting two lists of half the size of the original:
  2 O(½M log ½M) ≤ O(M log M)
  As a result, each level of the tree takes O(M log M) time to build.
- Because the list of points is halved at each level, there are O(log M) levels in the k-d tree.
As a result, the overall time complexity of constructing the k-d tree is
O(M [log M]²)
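The same argument can be written as a recurrence. Here T(M) is the construction time for M points and c is an assumed constant from the sort's cost; the list slicing done at each level costs only O(M) and is absorbed into the sort term.

```latex
T(M) = 2\,T\!\left(\tfrac{M}{2}\right) + O(M \log M)
\;\Longrightarrow\;
T(M) \le \sum_{d=0}^{\log_2 M} 2^{d} \cdot c\,\frac{M}{2^{d}} \log\frac{M}{2^{d}}
\le c\,M \log M \,(\log_2 M + 1)
= O\!\left(M [\log M]^2\right)
```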
For the nearest neighbor search, I use an algorithm outlined on the k-d tree Wikipedia page. This algorithm is a variation on searching a binary search tree.
Here is a Python function, find_nearest_neighbor, that implements this search algorithm:
NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])
NNRecord.__doc__ = """
Used to keep track of the current best guess during a
nearest neighbor search.
"""

def find_nearest_neighbor(*, tree, point):
    """Find the nearest neighbor in a k-d tree for a given
    point.
    """
    k = len(point)
    best = None

    def search(*, tree, depth):
        """Recursively search through the k-d tree to find the
        nearest neighbor.
        """
        nonlocal best
        if tree is None:
            return
        distance = SED(tree.value, point)
        if best is None or distance < best.distance:
            best = NNRecord(point=tree.value, distance=distance)
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        # Descend first into the side of the splitting plane
        # that contains the query point.
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left
        search(tree=close, depth=depth+1)
        # Only search the other side if the splitting plane is
        # closer than the best match found so far.
        if diff**2 < best.distance:
            search(tree=away, depth=depth+1)

    search(tree=tree, depth=0)
    return best.point

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(10, 1))
(4, 1)
The average time complexity of a nearest neighbor search in a balanced k-d tree of M points is:
O(log M)
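One way to see the O(log M) average behavior is to count how many tree nodes the pruned search actually visits as M grows. The sketch below repeats SED and kdtree so it runs on its own; count_visited is a hypothetical instrumented copy of find_nearest_neighbor that uses math.inf instead of None for the initial best distance. The numbers it prints depend on the random data, so treat them as illustrative.

```python
import collections
import math
import operator
import random

BT = collections.namedtuple("BT", ["value", "left", "right"])

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def kdtree(points):
    """Median-split k-d tree construction, repeated from above."""
    points = list(points)
    k = len(points[0])

    def build(*, points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(
            value=points[middle],
            left=build(points=points[:middle], depth=depth + 1),
            right=build(points=points[middle + 1:], depth=depth + 1),
        )

    return build(points=points, depth=0)

def count_visited(tree, point):
    """Hypothetical instrumented copy of find_nearest_neighbor.

    Returns (nearest point, number of tree nodes visited).
    """
    k = len(point)
    best_point, best_dist = None, math.inf
    visited = 0

    def search(tree, depth):
        nonlocal best_point, best_dist, visited
        if tree is None:
            return
        visited += 1
        distance = SED(tree.value, point)
        if distance < best_dist:
            best_point, best_dist = tree.value, distance
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        # Visit the side containing the query point first.
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left
        search(close, depth + 1)
        # Prune the far side unless the splitting plane is closer than the best match.
        if diff ** 2 < best_dist:
            search(away, depth + 1)

    search(tree, 0)
    return best_point, visited

random.seed(1)
for M in (1_000, 10_000, 100_000):
    refs = [(random.random(), random.random()) for _ in range(M)]
    tree = kdtree(refs)
    queries = [(random.random(), random.random()) for _ in range(200)]
    average = sum(count_visited(tree, q)[1] for q in queries) / len(queries)
    print(f"M = {M:>7,}: {average:.1f} nodes visited per query on average")
```

For uniformly random 2-D points, the count should grow slowly rather than in proportion to M. For clustered data or many dimensions, k-d tree pruning is typically less effective.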
I combine kdtree and find_nearest_neighbor to create a new solution to the "Nearest Neighbor Problem":
def nearest_neighbor_kdtree(*, query_points, reference_points):
    """Use a k-d tree to solve the "Nearest Neighbor Problem"."""
    tree = kdtree(reference_points)
    return {
        query_p: find_nearest_neighbor(tree=tree, point=query_p)
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]
nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
(5, 1): (4, 1),
(7, 3): (4, 1),
(8, 9): (3, 5),
(10, 1): (4, 1),
(3, 3): (3, 2)}
nn_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
nn_kdtree == nn_bf
True
What is the time complexity of this algorithm? For N query points and M reference points:
- Building the k-d tree takes O(M [log M]²) time.
- Each nearest neighbor search takes O(log M) time.
- O(N) nearest neighbor searches are conducted.
As a result, the overall time complexity of nearest_neighbor_kdtree is
O(M [log M]² + N log M)
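To get a feel for the gap, one can plug numbers into the two bounds. This ignores constant factors and uses base-2 logarithms, so it only indicates orders of magnitude; op_counts is a hypothetical helper. N = M = 4000 matches the sizes used in the cProfile measurements later in this post.

```python
import math

def op_counts(N, M):
    """Rough step counts implied by the two big-O bounds (constants ignored)."""
    brute_force = N * M
    kd = M * math.log2(M) ** 2 + N * math.log2(M)
    return brute_force, kd

bf, kd = op_counts(4000, 4000)
print(f"brute force ~ {bf:,}")  # 16,000,000
print(f"k-d tree    ~ {kd:,.0f}")
```

With these sizes, the brute force bound comes out roughly 26 times larger. Constant factors, which this ignores, also affect the actual running times.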
This seems faster than the brute force algorithm, but it is hard for me to think of simple examples of this. Instead, I will give empirical measurements.
Confirming that the k-d tree algorithm and the brute force algorithm give the same results
When working with complex algorithms like this, it's easy for me to make a mistake. To reduce my anxiety that I made a mistake with the k-d tree algorithm, I will generate test data and make sure that the results match the results of the brute force algorithm:
import random

random_point = lambda: (random.random(), random.random())
reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]

solution_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points
)
solution_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points
)
solution_bf == solution_kdtree
True
Before this test, I was already sure that the brute force algorithm is correct. Running this test makes me confident that my k-d tree algorithm is also correct.
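The random floats used here make exact ties very unlikely. With integer coordinates, ties are common, and the two algorithms may then return different (but equally near) reference points, so comparing the dictionaries with == could report a false mismatch. A more tolerant check compares distances instead. A sketch with a hypothetical same_distances helper:

```python
def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def same_distances(solution_a, solution_b):
    """True if both solutions give each query point a neighbor at the same SED.

    Tolerates ties, where two algorithms may legitimately pick different
    reference points at the same distance.
    """
    return solution_a.keys() == solution_b.keys() and all(
        SED(q, solution_a[q]) == SED(q, solution_b[q]) for q in solution_a
    )

# Query (0, 0) is equally close to (1, 0) and (0, 1).
a = {(0, 0): (1, 0)}
b = {(0, 0): (0, 1)}
print(a == b)                # False
print(same_distances(a, b))  # True
```

To use it with the test above, pass solution_bf and solution_kdtree instead of a and b.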
Comparing the measured speed of the k-d tree algorithm and the brute force algorithm
I generate some test data and use the cProfile module to measure the running time of each algorithm:
import cProfile

reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]

cProfile.run("""
nearest_neighbor_bf(
    reference_points=reference_points,
    query_points=query_points,
)
""")
96004005 function calls in 26.252 seconds
...
cProfile.run("""
nearest_neighbor_kdtree(
    reference_points=reference_points,
    query_points=query_points,
)
""")
516215 function calls (422736 primitive calls) in 0.231 seconds
...
The brute force algorithm takes 26.252 seconds and the k-d tree algorithm takes 0.231 seconds (over 100 times faster).
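The Python documentation notes that the profilers add overhead to Python-level function calls and recommends the timeit module for benchmarking. Since the brute force version makes far more Python function calls (about 96 million versus about half a million above), a plain wall-clock measurement is a useful cross-check on the ratio. A sketch with timeit.repeat on smaller inputs (500 points, an arbitrary choice); timing nearest_neighbor_kdtree from earlier in the same way gives the comparison.

```python
import random
import timeit

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def nearest_neighbor_bf(*, query_points, reference_points):
    """Brute force solution, repeated here so the block runs on its own."""
    return {
        q: min(reference_points, key=lambda X: SED(X, q))
        for q in query_points
    }

random.seed(2)
random_point = lambda: (random.random(), random.random())
reference_points = [random_point() for _ in range(500)]
query_points = [random_point() for _ in range(500)]

# number=1 times one full run; repeat=3 collects three such timings.
# Taking the minimum reduces noise from other processes.
timings = timeit.repeat(
    lambda: nearest_neighbor_bf(
        query_points=query_points, reference_points=reference_points,
    ),
    repeat=3,
    number=1,
)
print(f"brute force, 500 x 500 points: {min(timings):.3f} s")
```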
In conclusion...
In this week's post, you learned how to solve the "Nearest Neighbor Problem" efficiently using a k-d tree (a kind of spatial index). K-d trees allow you to efficiently query large sets of spatial data to find close points. K-d trees and other spatial indices are used in databases to optimize queries.
My challenge to you:
I wrote a previous post, "Building a command line tool to simulate the spread of an infection", that showed you how to build a simulation of an infection spreading in a population.
Determining if a susceptible person gets infected requires finding the nearest infected person. The algorithm that post uses to find a close infected person is a brute force algorithm.
My challenge to you is to replace the brute force algorithm in the infection simulator with the k-d tree algorithm that we built in this post.
If you enjoyed this week's post, share it with your friends and stay tuned for next week's post. See you then!
(If you spot any errors or typos on this post, contact me via my contact page.)