
Solving the Nearest Neighbor Problem using Python

By John Lekberg on April 17, 2020.


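This week's post is about solving the "Nearest Neighbor Problem". You will learn how to solve it with a brute force algorithm and how to solve it faster with a k-d tree (a kind of spatial index).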

Problem Statement

This problem deals with points in real coordinate space.

You are given a collection of "reference points", e.g.

[ (1, 2), (3, 2), (4, 1), (3, 5) ]

And you are given a collection of "query points", e.g.

[ (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3) ]

Your goal is to find the nearest reference point for every query point. E.g.

For query point (3, 4), the closest reference point is (3, 5).

How I represent the data in the problem

I represent a point in real coordinate space as a tuple object. E.g.

(3, 4)

To compare distances (to find the nearest point) I use squared Euclidean distance (SED):

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i-j)**2 for i, j in zip(X, Y))

SED( (3, 4), (4, 9) )
26

SED ranks distances the same as Euclidean distance, so it is acceptable to use SED to find a "nearest neighbor".
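As a minimal sketch of why this holds: the square root is monotonically increasing for non-negative values, so sorting by SED and sorting by true Euclidean distance give the same order. The helper `euclidean` below is my own name for illustration and is not used elsewhere in this post.

```python
import math

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def euclidean(X, Y):
    """True Euclidean distance: the square root of SED (illustrative helper)."""
    return math.sqrt(SED(X, Y))

reference_points = [(1, 2), (3, 2), (4, 1), (3, 5)]
query = (3, 4)

# Rank the reference points by each distance measure.
by_sed = sorted(reference_points, key=lambda p: SED(p, query))
by_euclidean = sorted(reference_points, key=lambda p: euclidean(p, query))

print(by_sed == by_euclidean)  # True
print(by_sed[0])               # (3, 5)
```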

I represent the solution as a dictionary object that maps query points to the nearest reference point. E.g.

{ (5, 1): (4, 1) }

Creating a brute force solution

A brute force solution to the "Nearest Neighbor Problem" will, for each query point, measure the distance (using SED) to every reference point and select the closest reference point:

def nearest_neighbor_bf(*, query_points, reference_points):
    """Use a brute force algorithm to solve the
    "Nearest Neighbor Problem".
    """
    return {
        query_p: min(
            reference_points,
            key=lambda X: SED(X, query_p),
        )
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}

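What is the time complexity of this solution? For N query points and M reference points, each query point is compared against all M reference points, and each SED computation takes constant time for a fixed number of dimensions.

As a result, the overall time complexity of the brute force algorithm is

O(N M).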
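One way to make the O(N M) figure concrete is to count how often the distance function runs. The sketch below wraps SED in a counter. The function `count_distance_calls` is a hypothetical helper I introduce only for this illustration.

```python
def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def count_distance_calls(*, query_points, reference_points):
    """Run the brute force search while counting distance evaluations."""
    calls = 0

    def counted_SED(X, Y):
        nonlocal calls
        calls += 1
        return SED(X, Y)

    solution = {
        query_p: min(
            reference_points,
            key=lambda X: counted_SED(X, query_p),
        )
        for query_p in query_points
    }
    return solution, calls

reference_points = [(1, 2), (3, 2), (4, 1), (3, 5)]
query_points = [(3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)]

solution, calls = count_distance_calls(
    query_points=query_points,
    reference_points=reference_points,
)
print(calls)  # 24, i.e. 6 query points * 4 reference points
```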

How could I solve this problem faster?

Using a spatial index to create a faster solution

A spatial index is a data structure used to optimize spatial queries.

A k-dimensional tree (k-d tree) is a spatial index that uses a binary tree to divide up real coordinate space. Why should I use a k-d tree to solve the "Nearest Neighbor Problem"?

For more information about k-d trees, read these documents:

I construct a balanced k-d tree using an algorithm outlined on the k-d tree Wikipedia page. This is a recursive algorithm that divides the set of reference points in half at the median and recursively builds left- and right-subtrees.

Here's a Python function, kdtree, that implements the construction algorithm:

import collections
import operator

BT = collections.namedtuple("BT", ["value", "left", "right"])
BT.__doc__ = """
A Binary Tree (BT) with a node value, and left- and
right-subtrees.
"""

def kdtree(points):
    """Construct a k-d tree from an iterable of points.
    
    This algorithm is taken from Wikipedia. For more details,
    
    > https://en.wikipedia.org/wiki/K-d_tree#Construction
    
    """
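    # Materialize the input first so that any iterable (not just a
    # sequence) can be indexed, as the docstring promises.
    points = list(points)
    k = len(points[0])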
    
    def build(*, points, depth):
        """Build a k-d tree from a set of points at a given
        depth.
        """
        if len(points) == 0:
            return None
        
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        
        return BT(
            value = points[middle],
            left = build(
                points=points[:middle],
                depth=depth+1,
            ),
            right = build(
                points=points[middle+1:],
                depth=depth+1,
            ),
        )
    
    return build(points=list(points), depth=0)

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
kdtree(reference_points)
BT(value=(3, 5),
   left=BT(value=(3, 2),
           left=BT(value=(1, 2), left=None, right=None),
           right=None),
   right=BT(value=(4, 1), left=None, right=None))
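To check that a tree like the one printed above is balanced, one option is to measure its height. This is a small sketch; `height` is a helper name I made up, and the tree is written out by hand so the snippet runs on its own.

```python
import collections
import math

BT = collections.namedtuple("BT", ["value", "left", "right"])

def height(tree):
    """Number of nodes on the longest path from the root to a leaf."""
    if tree is None:
        return 0
    return 1 + max(height(tree.left), height(tree.right))

# The tree printed above, written out by hand.
tree = BT(
    value=(3, 5),
    left=BT(
        value=(3, 2),
        left=BT(value=(1, 2), left=None, right=None),
        right=None,
    ),
    right=BT(value=(4, 1), left=None, right=None),
)

M = 4  # number of reference points in the tree
print(height(tree))                  # 3
print(math.floor(math.log2(M)) + 1)  # 3
```

For this example the height matches floor(log2 M) + 1, which is what a median split should produce.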

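What is the time complexity of constructing a k-d tree from M points? The recursion is O(log M) levels deep because every call splits its points in half at the median, and the sorting done across each level takes O(M log M) time in total.

As a result, the overall time complexity of constructing the k-d tree is

O(M [log M]^2)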

For the nearest neighbor search, I use an algorithm outlined on the k-d tree Wikipedia page. This algorithm is a variation on searching a binary search tree.

Here is a Python function, find_nearest_neighbor, that implements this search algorithm:

NNRecord = collections.namedtuple("NNRecord", ["point", "distance"])
NNRecord.__doc__ = """
Used to keep track of the current best guess during a nearest
neighbor search.
"""

def find_nearest_neighbor(*, tree, point):
    """Find the nearest neighbor in a k-d tree for a given
    point.
    """
    k = len(point)
    
    best = None
    def search(*, tree, depth):
        """Recursively search through the k-d tree to find the
        nearest neighbor.
        """
        nonlocal best
        
        if tree is None:
            return
        
        distance = SED(tree.value, point)
        if best is None or distance < best.distance:
            best = NNRecord(point=tree.value, distance=distance)
        
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left
        
        search(tree=close, depth=depth+1)
        if diff**2 < best.distance:
            search(tree=away, depth=depth+1)
    
    search(tree=tree, depth=0)
    return best.point

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
tree = kdtree(reference_points)
find_nearest_neighbor(tree=tree, point=(10, 1))
(4, 1)
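To see the pruning at work, here is a sketch that counts how many tree nodes a single search touches. It repeats a condensed version of the construction so that it runs on its own. The function `nearest_with_visit_count` and the random seed are assumptions of mine for illustration, not part of the original algorithm description.

```python
import collections
import operator
import random

BT = collections.namedtuple("BT", ["value", "left", "right"])

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def kdtree(points):
    """Condensed version of the median-split construction."""
    points = list(points)
    k = len(points[0])

    def build(points, depth):
        if not points:
            return None
        points.sort(key=operator.itemgetter(depth % k))
        middle = len(points) // 2
        return BT(
            points[middle],
            build(points[:middle], depth + 1),
            build(points[middle + 1:], depth + 1),
        )

    return build(points, 0)

def nearest_with_visit_count(*, tree, point):
    """Hypothetical variant of the search that also reports how many
    nodes were visited."""
    k = len(point)
    best_point, best_distance, visited = None, None, 0

    def search(tree, depth):
        nonlocal best_point, best_distance, visited
        if tree is None:
            return
        visited += 1
        distance = SED(tree.value, point)
        if best_distance is None or distance < best_distance:
            best_point, best_distance = tree.value, distance
        axis = depth % k
        diff = point[axis] - tree.value[axis]
        if diff <= 0:
            close, away = tree.left, tree.right
        else:
            close, away = tree.right, tree.left
        search(close, depth + 1)
        # Only cross the splitting plane if a closer point could lie beyond it.
        if diff ** 2 < best_distance:
            search(away, depth + 1)

    search(tree, 0)
    return best_point, visited

random.seed(0)  # arbitrary seed, only so the run is repeatable
reference_points = [(random.random(), random.random()) for _ in range(1000)]
query = (0.5, 0.5)

tree = kdtree(reference_points)
nearest, visited = nearest_with_visit_count(tree=tree, point=query)
print(visited, "of", len(reference_points), "nodes visited")
```

The exact count depends on the data, so treat the printed number as an observation about one run rather than a guarantee.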

The average time complexity of a nearest neighbor search in a balanced k-d tree of M points is:

O(log M)

I combine kdtree and find_nearest_neighbor to create a new solution to the "Nearest Neighbor Problem":

def nearest_neighbor_kdtree(*, query_points, reference_points):
    """Use a k-d tree to solve the "Nearest Neighbor Problem"."""
    tree = kdtree(reference_points)
    return {
        query_p: find_nearest_neighbor(tree=tree, point=query_p)
        for query_p in query_points
    }

reference_points = [ (1, 2), (3, 2), (4, 1), (3, 5) ]
query_points = [
    (3, 4), (5, 1), (7, 3), (8, 9), (10, 1), (3, 3)
]

nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
{(3, 4): (3, 5),
 (5, 1): (4, 1),
 (7, 3): (4, 1),
 (8, 9): (3, 5),
 (10, 1): (4, 1),
 (3, 3): (3, 2)}
nn_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points,
)
nn_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points,
)
nn_kdtree == nn_bf
True

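What is the time complexity of this algorithm? For N query points and M reference points, constructing the k-d tree takes O(M [log M]^2) time, and each of the N searches takes O(log M) time on average.

As a result, the overall average time complexity of nearest_neighbor_kdtree is

O(M [log M]^2 + N log M)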

This seems faster than the brute force algorithm, but it is hard for me to think of simple examples of this. Instead, I will give empirical measurements.
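Before measuring, a rough back-of-the-envelope comparison can still be sketched by plugging sizes into the two formulas. This ignores constant factors entirely, so it says nothing about real running time. The two estimate functions are names I made up for this illustration.

```python
import math

def brute_force_estimate(N, M):
    """Rough operation count for O(N M), ignoring constant factors."""
    return N * M

def kdtree_estimate(N, M):
    """Rough operation count for O(M [log M]^2 + N log M),
    ignoring constant factors."""
    log_M = math.log2(M)
    return M * log_M ** 2 + N * log_M

# Compare the two estimates for equal numbers of query and reference points.
for size in (10, 100, 1000, 4000):
    bf = brute_force_estimate(size, size)
    kd = kdtree_estimate(size, size)
    print(size, bf, round(kd), round(bf / kd, 1))
```

For very small inputs the k-d tree estimate can come out larger than the brute force one. The advantage only shows up as the number of points grows.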

Confirming that the k-d tree algorithm and the brute force algorithm give the same results

When working with complex algorithms like this, it's easy for me to make a mistake. To reduce my anxiety that I made a mistake with the k-d tree algorithm, I will generate test data and make sure that the results match the results of the brute force algorithm:

import random

random_point = lambda: (random.random(), random.random())
reference_points = [ random_point() for _ in range(3000) ]
query_points = [ random_point() for _ in range(3000) ]

solution_bf = nearest_neighbor_bf(
    reference_points = reference_points,
    query_points = query_points
)
solution_kdtree = nearest_neighbor_kdtree(
    reference_points = reference_points,
    query_points = query_points
)

solution_bf == solution_kdtree
True

Before this test, I was already sure that the brute force algorithm is correct. Running this test makes me confident that my k-d tree algorithm is also correct.
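One caveat with comparing the two dictionaries directly: if two reference points are equally close to a query point, the two algorithms may break the tie differently, and the equality test would fail even though both answers are valid. With random floating point data this is very unlikely, but a tie-tolerant check can be sketched as below. The function `same_nearest_distances` is a hypothetical helper of mine.

```python
def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def same_nearest_distances(solution_a, solution_b):
    """Check that two solutions agree on the distance to the nearest
    reference point for every query point, even if they picked
    different (equally close) reference points."""
    if solution_a.keys() != solution_b.keys():
        return False
    return all(
        SED(query_p, solution_a[query_p]) == SED(query_p, solution_b[query_p])
        for query_p in solution_a
    )

# (2, 0) and (0, 2) are equally far from the query point (0, 0).
solution_a = {(0, 0): (2, 0)}
solution_b = {(0, 0): (0, 2)}

print(solution_a == solution_b)                        # False
print(same_nearest_distances(solution_a, solution_b))  # True
```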

Comparing the measured speed of the k-d tree algorithm and the brute force algorithm

I generate some test data and use the cProfile module to measure:

import cProfile

reference_points = [ random_point() for _ in range(4000) ]
query_points = [ random_point() for _ in range(4000) ]

cProfile.run("""
nearest_neighbor_bf(
    reference_points=reference_points,
    query_points=query_points,
)
""")
96004005 function calls in 26.252 seconds

...
cProfile.run("""
nearest_neighbor_kdtree(
    reference_points=reference_points,
    query_points=query_points,
)
""")
516215 function calls (422736 primitive calls) in 0.231 seconds

...

The brute force algorithm takes 26.252 seconds and the k-d tree algorithm takes 0.231 seconds (over 100 times faster).
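Keep in mind that cProfile adds overhead to every function call, so these numbers are best read as a relative comparison. As a complementary check, wall-clock time can be measured with time.perf_counter. The sketch below times only the brute force function on a smaller input so that it runs quickly. The helper `wall_time` and the input size of 200 are my own choices for illustration.

```python
import random
import time

def SED(X, Y):
    """Compute the squared Euclidean distance between X and Y."""
    return sum((i - j) ** 2 for i, j in zip(X, Y))

def nearest_neighbor_bf(*, query_points, reference_points):
    """Brute force solution, repeated here so the snippet is self-contained."""
    return {
        query_p: min(reference_points, key=lambda X: SED(X, query_p))
        for query_p in query_points
    }

def wall_time(function, **kwargs):
    """Return (result, elapsed seconds) for a single call of function."""
    start = time.perf_counter()
    result = function(**kwargs)
    return result, time.perf_counter() - start

def random_point():
    return (random.random(), random.random())

reference_points = [random_point() for _ in range(200)]
query_points = [random_point() for _ in range(200)]

result, elapsed = wall_time(
    nearest_neighbor_bf,
    query_points=query_points,
    reference_points=reference_points,
)
print(f"brute force: {elapsed:.4f} seconds")
```

The same helper can be pointed at nearest_neighbor_kdtree to compare the two without profiler overhead.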

In conclusion...

In this week's post, you learned how to solve the "Nearest Neighbor Problem" efficiently using a k-d tree (a kind of spatial index). K-d trees allow you to efficiently query large sets of spatial data to find close points. K-d trees and other spatial indices are used in databases to optimize queries.

My challenge to you:

I wrote a previous post, "Building a command line tool to simulate the spread of an infection", that showed you how to build a simulation of an infection spreading in a population.

Determining if a susceptible person gets infected requires finding the nearest infected person. That post uses a brute force algorithm to find a nearby infected person.

My challenge to you is to replace the brute force algorithm in the infection simulator with the k-d tree algorithm that we built in this post.

If you enjoyed this week's post, share it with your friends and stay tuned for next week's post. See you then!