Solving the Nearest Neighbor Problem using Python

By John Lekberg on April 17, 2020.


This week's post is about solving the "Nearest Neighbor Problem". You will learn:

- What the "Nearest Neighbor Problem" is.
- How to solve it with a brute force algorithm.
- How to solve it faster using a k-d tree (a spatial index).

Problem Statement

This problem deals with points in real coordinate space.

You are given a collection of "reference points", e.g.

[ (1, 2), (3, 2), (4, 1), (3, 5) ]

And you are given a collection of "query points", e.g.

[ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]

Your goal is to find the nearest reference point for every query point. E.g.

For query point (3, 4), the closest reference point is (3, 5).

How I represent the data in the problem

I represent a point in real coordinate space as a tuple object. E.g.

(3, 4)

To compare distances (to find the nearest point) I use squared Euclidean distance (SED):

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i-j)**2 for i, j in zip(X, Y))

SED( (3, 4), (4, 9) )
26

SED ranks distances the same as Euclidean distance, because squaring is monotonic for non-negative numbers. So it is acceptable to use SED to find a "nearest neighbor", and it avoids computing a square root for every comparison.
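Here is a quick sanity check of that claim (my own addition, not part of the original solution), using math.dist from Python 3.8+ to compute the true Euclidean distance:

import math

points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query = (3, 4)

by_sed = sorted(points, key=lambda p: SED(query, p))
by_euclid = sorted(points, key=lambda p: math.dist(query, p))
by_sed == by_euclid
True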

I represent the solution as a dictionary object that maps query points to the nearest reference point. E.g.

{ (5, 1): (4, 1) }

Creating a brute force solution

A brute force solution to the "Nearest Neighbor Problem" will, for each query point, measure the distance (using SED) to every reference point and select the closest reference point:

def nearest_neighbor_bf(*, query_points, reference_points):
    """Use a brute force algorithm to solve the
    "Nearest Neighbor Problem".
    """
    return {
        query_p: min(
            reference_points,
            key=lambda X: SED(X, query_p),
        )
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}

What is the time complexity of this solution? For N query points and M reference points:

Finding the nearest reference point for a single query point means computing SED against all M reference points, which takes O(M) time. This is repeated once for each of the N query points.

As a result, the overall time complexity of the brute force algorithm is

O(N M).

How could I solve this problem faster?

Using a spatial index to create a faster solution

A spatial index is a data structure used to optimize spatial queries. Examples include k-d trees, quadtrees, and R-trees.

A k-dimensional tree (k-d tree) is a spatial index that uses a binary tree to divide up real coordinate space. Why should I use a k-d tree to solve the "Nearest Neighbor Problem"? Because, for M reference points, the tree can be constructed once in O(M (log M)²) time, and then each nearest neighbor query takes only O(log M) time on average.

For more information about k-d trees, read this document:

https://en.wikipedia.org/wiki/K-d_tree

I construct a balanced k-d tree using an algorithm outlined on the k-d tree Wikipedia page. This is a recursive algorithm that splits the set of reference points at the median along one coordinate axis (cycling through the axes with depth) and recursively builds the left and right subtrees.

Here's a Python function, kdtree, that implements the construction algorithm:

import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])
BT.__doc__ = """
A Binary Tree (BT) with a node value, and left- and
right-subtrees.
"""

def kdtree(points):
    """Construct a k-d tree from an iterable of points.
    
    This algorithm is taken from Wikipedia. For more details,
    
    > https://en.wikipedia.org/wiki/K-d_tree#Construction
    
    """
    k = len(points[0])
    
    def build(*, points, depth):
        """Build a k-d tree from a set of points at a given
        depth.
        """
        if len(points) == 0:
            return None
        
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        
        return BT(
            value = points[middle],
            left = build(
                points=points[:middle],
                depth=depth+1,
            ),
            right = build(
                points=points[middle+1:],
                depth=depth+1,
            ),
        )
    
    return build(points=list(points), depth=0)

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)
BT(value=(3, 5),
   left=BT(value=(3, 2),
           left=BT(value=(1, 2), left=None, right=None),
           right=None),
   right=BT(value=(4, 1), left=None, right=None))

What is the time complexity of constructing a k-d tree from M points?

Because the points are split in half at the median, the tree has O(log M) levels. At each level, every point takes part in exactly one sort, so the sorting across a level takes O(M log M) time in total.

As a result, the overall time complexity of constructing the k-d tree is

O(M (log M)²)
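As an informal check that the construction really produces a balanced tree, here is a small helper of my own (height is not part of the original post) that counts the levels in a tree; for M points the result should be close to log2 M:

import random

def height(tree):
    """Number of levels in a k-d tree."""
    if tree is None:
        return 0
    return 1 + max(height(tree.left), height(tree.right))

random_point = lambda: (random.random(), random.random())
height(kdtree([ random_point() for _ in range(4000) ]))
12

The height is 12, which matches log2(4000) ≈ 12, so the tree is balanced as expected. (Because the split depends only on the number of points, the height is the same regardless of the random coordinates.)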

For the nearest neighbor search, I use an algorithm outlined on the k-d tree Wikipedia page. This algorithm is a variation on searching a binary search tree.

Here is a Python function, find_nearest_neighbor, that implements this search algorithm:

NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])
NNRecord.__doc__ = """
Used to keep track of the current best guess during a nearest
neighbor search.
"""

def find_nearest_neighbor(*, tree, point):
    """Find the nearest neighbor in a k-d tree for a given
    point.
    """
    k = len(point)
    
    best = None
    def search(*, tree, depth):
        """Recursively search through the k-d tree to find the
        nearest neighbor.
        """
        nonlocal best
        
        if tree is None:
            return
        
        distance = SED(tree.value, point)
        if best is None or distance < best.distance:
            best = NNRecord(point=tree.value, distance=distance)
        
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left
        
        search(tree=close, depth=depth+1)
        if diff**2 < best.distance:
            search(tree=away, depth=depth+1)
    
    search(tree=tree, depth=0)
    return best.point

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(10, 1))
(4, 1)

The average time complexity of a nearest neighbor search in a balanced k-d tree of M points is:

O(log M)
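Here is a rough empirical check of that claim, my own instrumentation sketch rather than part of the original post. Every node the search visits computes SED exactly once, so temporarily shadowing SED with a counting wrapper counts visited nodes:

import random

plain_SED = SED

def SED(X, Y):
    """Count calls to SED, then delegate to the original."""
    global visited
    visited += 1
    return plain_SED(X, Y)

random_point = lambda: (random.random(), random.random())
tree = kdtree([ random_point() for _ in range(4000) ])

visited = 0
find_nearest_neighbor(tree=tree, point=random_point())
visited  # typically a few dozen nodes, far fewer than 4000
SED = plain_SED  # restore the original function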

I combine kdtree and find_nearest_neighbor to create a new solution to the "Nearest Neighbor Problem":

def nearest_neighbor_kdtree(*, query_points, reference_points):
    """Use a k-d tree to solve the "Nearest Neighbor Problem"."""
    tree = kdtree(reference_points)
    return {
        query_p: find_nearest_neighbor(tree=tree, point=query_p)
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}
nn_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
nn_kdtree == nn_bf
True

What is the time complexity of this algorithm? For N query points and M reference points:

Constructing the k-d tree from the M reference points takes O(M (log M)²) time. Then each of the N query points is answered by a nearest neighbor search that takes O(log M) time on average.

As a result, the overall time complexity of nearest_neighbor_kdtree is

O(M (log M)² + N log M)

This seems faster than the brute force algorithm's O(N M), but it is hard to get a feel for the difference from the formulas alone. Instead, I will give empirical measurements.
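That said, a back-of-envelope calculation (my own, treating the big-O expressions as if they were exact operation counts, which they are not) hints at the size of the gap for N = M = 4000:

import math

M = N = 4000
N * M                                   # brute force: 16,000,000
M * math.log2(M)**2 + N * math.log2(M)  # k-d tree: about 620,000

That is roughly a 26-fold difference in raw operation counts, although big-O notation hides constant factors, so the measured ratio will differ.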

Confirming that the k-d tree algorithm and the brute force algorithm give the same results

When working with complex algorithms like this, it's easy for me to make a mistake. To reassure myself that the k-d tree algorithm is correct, I will generate random test data and check that its results match the results of the brute force algorithm:

import random

random_point = lambda: (random.random(), random.random())
reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]

solution_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points
)
solution_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points
)

solution_bf == solution_kdtree
True

Before this test, I was already sure that the brute force algorithm is correct. Running this test makes me confident that my k-d tree algorithm is also correct.

Comparing the measured speed of the k-d tree algorithm and the brute force algorithm

I generate some test data and use the cProfile module to measure how long each algorithm takes:

import cProfile

reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]

cProfile.run("""
nearest_neighbor_bf(
    reference_points=reference_points,
    query_points=query_points,
)
""")
96004005 function calls in 26.252 seconds

...
cProfile.run("""
nearest_neighbor_kdtree(
    reference_points=reference_points,
    query_points=query_points,
)
""")
516215 function calls (422736 primitive calls) in 0.231 seconds

...

The brute force algorithm takes 26.252 seconds and the k-d tree algorithm takes 0.231 seconds (over 100 times faster).
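cProfile adds per-call bookkeeping overhead of its own, so as a complementary measurement (my own addition, not in the original post), the timeit module can time one full run of each solver without instrumenting every function call:

import timeit

timeit.timeit(
    lambda: nearest_neighbor_bf(
        reference_points=reference_points,
        query_points=query_points,
    ),
    number=1,
)

timeit.timeit(
    lambda: nearest_neighbor_kdtree(
        reference_points=reference_points,
        query_points=query_points,
    ),
    number=1,
)

Whatever the exact numbers on your machine, the ranking of the two algorithms should match what cProfile reports.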

In conclusion...

In this week's post, you learned how to solve the "Nearest Neighbor Problem" efficiently using a k-d tree (a kind of spatial index). K-d trees allow you to efficiently query large sets of spatial data to find close points. K-d trees and other spatial indices are used in databases to optimize queries.

My challenge to you:

I wrote a previous post, "Building a command line tool to simulate the spread of an infection", that showed you how to build a simulation of an infection spreading in a population.

Determining if a susceptible person gets infected requires finding the nearest infected person. The algorithm that post uses to find a close infected person is a brute force algorithm.

My challenge to you is to replace the brute force algorithm in the infection simulator with the k-d tree algorithm that we built in this post.

If you enjoyed this week's post, share it with your friends and stay tuned for next week's post. See you then!


(If you spot any errors or typos on this post, contact me via my contact page.)