
Using key functions to sort data in Python

By John Lekberg on February 19, 2020.


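This week's blog post is about key functions, which are used to sort data using more advanced comparisons, e.g. case-insensitive string sorting. You will learn what key functions are, how the standard library uses them, and how the operator module makes them easy to create.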

What are key functions?

Key functions are functions that compute the value used to compare objects, without permanently transforming the objects themselves. For example, I have a list of students:

from collections import namedtuple

Student = namedtuple("Student", ["name", "age"])

students = [
    Student(name="John", age=23),
    Student(name="Kyle", age=19),
    Student(name="Barry", age=22),
    Student(name="Kai", age=24),
]

I want to sort the students by age. But because of the order of the namedtuple fields, the default sort will sort by name, then by age:

sorted(students)
[Student(name='Barry', age=22),
 Student(name='John', age=23),
 Student(name='Kai', age=24),
 Student(name='Kyle', age=19)]

So I use a key function, lambda student: student.age, that only compares students by age:

sorted(students, key=lambda student: student.age)
[Student(name='Kyle', age=19),
 Student(name='Barry', age=22),
 Student(name='John', age=23),
 Student(name='Kai', age=24)]

How would I sort the students by age without using key functions? I would use a technique known as the Schwartzian transform (also known as "decorate-sort-undecorate"):

[
    student
    for _, _, student in sorted(
        (student.age, idx, student)
        for idx, student in enumerate(students)
    )
]
[Student(name='Kyle', age=19),
 Student(name='Barry', age=22),
 Student(name='John', age=23),
 Student(name='Kai', age=24)]
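Why does the decorated tuple include idx? A minimal sketch, assuming a hypothetical Pet class that defines no ordering: without the index, a tie on age falls through to comparing the objects themselves, which fails for objects that do not support <. The index breaks ties first and also keeps the original order of equal items.

```python
class Pet:
    """Hypothetical class with no ordering defined, so Pet < Pet raises TypeError."""

    def __init__(self, name, age):
        self.name = name
        self.age = age


pets = [Pet("Rex", 3), Pet("Tom", 3), Pet("Ada", 1)]

# Without the index, the tie on age 3 forces a comparison between Pet objects.
try:
    sorted((pet.age, pet) for pet in pets)
    tie_failed = False
except TypeError:
    tie_failed = True

# With the index as a tie-breaker, Pet objects are never compared.
decorated = sorted((pet.age, idx, pet) for idx, pet in enumerate(pets))
names = [pet.name for _, _, pet in decorated]

print(tie_failed)  # True
print(names)       # ['Ada', 'Rex', 'Tom']
```

A key function such as lambda pet: pet.age avoids the problem entirely, because only the ages are compared.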

Key functions are more succinct and more readable than Schwartzian transforms, so I recommend sticking with key functions.
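One detail that may help build intuition: sorted calls the key function once per element and then compares the resulting keys, which is essentially decorate-sort-undecorate done for you. The sketch below counts the calls. The names counting_age and calls are hypothetical and exist only for this illustration.

```python
from collections import namedtuple

Student = namedtuple("Student", ["name", "age"])

students = [
    Student(name="John", age=23),
    Student(name="Kyle", age=19),
    Student(name="Barry", age=22),
    Student(name="Kai", age=24),
]

calls = []  # hypothetical log of every student the key function sees


def counting_age(student):
    """Key function that records each call before returning the age."""
    calls.append(student.name)
    return student.age


by_age = sorted(students, key=counting_age)

# One key call per student, no matter how many comparisons the sort makes.
print(len(calls))                            # 4
print([student.name for student in by_age])  # ['Kyle', 'Barry', 'John', 'Kai']
```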

How does the standard library use key functions?

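The standard library uses key functions for sorting in three places: the built-in functions sorted, min, and max; the list.sort method; and the heapq functions nlargest, nsmallest, and merge.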

Let's say I have a list of students:

from collections import namedtuple

Student = namedtuple("Student", ["name", "age"])

students = [
    Student(name="John", age=23),
    Student(name="Kyle", age=19),
    Student(name="Barry", age=22),
    Student(name="Kai", age=24),
]

Using sorted to sort the students by age and create a new list:

sorted(students, key=lambda student: student.age)
[Student(name='Kyle', age=19),
 Student(name='Barry', age=22),
 Student(name='John', age=23),
 Student(name='Kai', age=24)]

Using list.sort to sort the list of students by age, sorting in-place:

students.sort(key=lambda student: student.age)
students
[Student(name='Kyle', age=19),
 Student(name='Barry', age=22),
 Student(name='John', age=23),
 Student(name='Kai', age=24)]

Using max to get the oldest student:

max(students, key=lambda student: student.age)
Student(name='Kai', age=24)

Using min to get the youngest student:

min(students, key=lambda student: student.age)
Student(name='Kyle', age=19)

Using heapq.nlargest to get the two oldest students:

from heapq import nlargest
nlargest(2, students, key=lambda student: student.age)
[Student(name='Kai', age=24),
 Student(name='John', age=23)]

Using heapq.nsmallest to get the two youngest students:

from heapq import nsmallest
nsmallest(2, students, key=lambda student: student.age)
[Student(name='Kyle', age=19),
 Student(name='Barry', age=22)]
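As a sanity check, nlargest and nsmallest should give the same results as sorting the whole list and slicing it, which is how the heapq documentation describes them. A small sketch, using a named function age in place of the lambda (an assumption made for readability):

```python
from collections import namedtuple
from heapq import nlargest, nsmallest

Student = namedtuple("Student", ["name", "age"])

students = [
    Student(name="John", age=23),
    Student(name="Kyle", age=19),
    Student(name="Barry", age=22),
    Student(name="Kai", age=24),
]


def age(student):
    """Key function: compare students by age only."""
    return student.age


two_oldest = nlargest(2, students, key=age)
two_youngest = nsmallest(2, students, key=age)

# Same answers as sorting everything and taking a slice.
print(two_oldest == sorted(students, key=age, reverse=True)[:2])  # True
print(two_youngest == sorted(students, key=age)[:2])              # True
```

The heapq versions tend to pay off when the number of items requested is small compared to the size of the data.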

Here's an example of using heapq.merge. I have a list of male students sorted by age:

age = lambda student: student.age

male_students = [
    Student(name="John", age=23),
    Student(name="Kyle", age=19),
    Student(name="Barry", age=22),
    Student(name="Kai", age=24),
]
male_students.sort(key=age)
male_students
[Student(name='Kyle', age=19),
 Student(name='Barry', age=22),
 Student(name='John', age=23),
 Student(name='Kai', age=24)]

And I have a list of female students sorted by age:

female_students = [
    Student(name="Susan", age=19),
    Student(name="Emily", age=24),
    Student(name="Caroline", age=25),
    Student(name="Maru", age=23),
]
female_students.sort(key=age)
female_students
[Student(name='Susan', age=19),
 Student(name='Maru', age=23),
 Student(name='Emily', age=24),
 Student(name='Caroline', age=25)]

I use heapq.merge to get a list of all students, sorted by age:

from heapq import merge

list(merge(male_students, female_students, key=age))
[Student(name='Kyle', age=19),
 Student(name='Susan', age=19),
 Student(name='Barry', age=22),
 Student(name='John', age=23),
 Student(name='Maru', age=23),
 Student(name='Kai', age=24),
 Student(name='Emily', age=24),
 Student(name='Caroline', age=25)]

Why is heapq.merge better in this case than combining the lists and sorting that? heapq.merge is more efficient because it takes advantage of the fact that the input lists are already sorted.
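A related point is that heapq.merge returns a lazy iterator, so it only does as much work as you ask for. The sketch below, which assumes the same student data as above, takes just the three youngest students with itertools.islice and then checks that a full merge matches sorting the combined list:

```python
from collections import namedtuple
from heapq import merge
from itertools import islice

Student = namedtuple("Student", ["name", "age"])


def age(student):
    """Key function: compare students by age only."""
    return student.age


male_students = sorted(
    [
        Student(name="John", age=23),
        Student(name="Kyle", age=19),
        Student(name="Barry", age=22),
        Student(name="Kai", age=24),
    ],
    key=age,
)
female_students = sorted(
    [
        Student(name="Susan", age=19),
        Student(name="Emily", age=24),
        Student(name="Caroline", age=25),
        Student(name="Maru", age=23),
    ],
    key=age,
)

# merge yields students one at a time; islice stops after the first three.
merged = merge(male_students, female_students, key=age)
three_youngest = [student.name for student in islice(merged, 3)]
print(three_youngest)  # ['Kyle', 'Susan', 'Barry']

# A full merge gives the same list as sorting the concatenation.
full = list(merge(male_students, female_students, key=age))
combined = sorted(male_students + female_students, key=age)
print(full == combined)  # True
```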

Using the operator module to easily create key functions

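Key functions can be any function, but three common patterns are getting an attribute of an object (object.attribute), getting an item from an object (object[key]), and calling a method on an object (object.method()).

The operator module provides three tools that address these common patterns: attrgetter, itemgetter, and methodcaller.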

I have a list of fruits that I want to organize:

from collections import namedtuple
from datetime import date

Fruit = namedtuple("Fruit", ["name", "price", "sell_by"])
fruits = [
    Fruit(name="Banana", price=10, sell_by=date(2020, 2, 17)),
    Fruit(name="Strawberry", price=8, sell_by=date(2020, 2, 1)),
    Fruit(name="Kiwi", price=8, sell_by=date(2020, 6, 3)),
    Fruit(name="Starfruit", price=12, sell_by=date(2020, 3, 4)),
    Fruit(name="Orange", price=5, sell_by=date(2020, 2, 28)),
]

I use attrgetter to sort the fruits by price, then by name:

from operator import attrgetter
sorted(fruits, key=attrgetter("price", "name"))
[Fruit(name='Orange', price=5, sell_by=datetime.date(2020, 2, 28)),
 Fruit(name='Kiwi', price=8, sell_by=datetime.date(2020, 6, 3)),
 Fruit(name='Strawberry', price=8, sell_by=datetime.date(2020, 2, 1)),
 Fruit(name='Banana', price=10, sell_by=datetime.date(2020, 2, 17)),
 Fruit(name='Starfruit', price=12, sell_by=datetime.date(2020, 3, 4))]

The equivalent lambda for

attrgetter("price", "name")

is

lambda fruit: (fruit.price, fruit.name)

attrgetter can even handle nested attributes (e.g. fruit.sell_by.day), which means that I can organize fruits by which day of the month they should sell by:

sorted(fruits, key=attrgetter("sell_by.day"))
[Fruit(name='Strawberry', price=8, sell_by=datetime.date(2020, 2, 1)),
 Fruit(name='Kiwi', price=8, sell_by=datetime.date(2020, 6, 3)),
 Fruit(name='Starfruit', price=12, sell_by=datetime.date(2020, 3, 4)),
 Fruit(name='Banana', price=10, sell_by=datetime.date(2020, 2, 17)),
 Fruit(name='Orange', price=5, sell_by=datetime.date(2020, 2, 28))]

The equivalent lambda for

attrgetter("sell_by.day")

is

lambda fruit: fruit.sell_by.day
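To see what these getters actually return, it can help to call them directly on a single object. A short sketch, where the variable names get_price, get_price_and_name, and get_day are made up for this illustration:

```python
from collections import namedtuple
from datetime import date
from operator import attrgetter

Fruit = namedtuple("Fruit", ["name", "price", "sell_by"])
kiwi = Fruit(name="Kiwi", price=8, sell_by=date(2020, 6, 3))

get_price = attrgetter("price")                   # one attribute: a bare value
get_price_and_name = attrgetter("price", "name")  # several attributes: a tuple
get_day = attrgetter("sell_by.day")               # dotted name: nested lookup

print(get_price(kiwi))           # 8
print(get_price_and_name(kiwi))  # (8, 'Kiwi')
print(get_day(kiwi))             # 3
```

The tuple returned for several attributes is what makes "sort by price, then by name" work, since tuples compare element by element.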

itemgetter is similar to attrgetter, but it uses object[key] notation, so it can be used for mappings and sequences. For example, if my list of fruits is a list of dictionaries instead of a list of namedtuples:

fruit_dicts = [
 {"name": "Banana", "price": 10, "sell_by": date(2020, 2, 17)},
 {"name": "Strawberry", "price": 8, "sell_by": date(2020, 2, 1)},
 {"name": "Kiwi", "price": 8, "sell_by": date(2020, 6, 3)},
 {"name": "Starfruit", "price": 12, "sell_by": date(2020, 3, 4)},
 {"name": "Orange", "price": 5, "sell_by": date(2020, 2, 28)}
]

I use itemgetter to sort the fruits by price, then by name:

from operator import itemgetter

sorted(fruit_dicts, key=itemgetter("price", "name"))
[{'name': 'Orange', 'price': 5, 'sell_by': datetime.date(2020, 2, 28)},
 {'name': 'Kiwi', 'price': 8, 'sell_by': datetime.date(2020, 6, 3)},
 {'name': 'Strawberry', 'price': 8, 'sell_by': datetime.date(2020, 2, 1)},
 {'name': 'Banana', 'price': 10, 'sell_by': datetime.date(2020, 2, 17)},
 {'name': 'Starfruit', 'price': 12, 'sell_by': datetime.date(2020, 3, 4)}]

The equivalent lambda for

itemgetter("price", "name")

is

lambda fruit: (fruit["price"], fruit["name"])

itemgetter also works on sequences, like lists and tuples. For example, if my list of fruits is a list of tuples instead of namedtuples:

fruit_tuples = [
 ("Banana", 10, date(2020, 2, 17)),
 ("Strawberry", 8, date(2020, 2, 1)),
 ("Kiwi", 8, date(2020, 6, 3)),
 ("Starfruit", 12, date(2020, 3, 4)),
 ("Orange", 5, date(2020, 2, 28))
]

I use itemgetter to sort the fruits by price, then by name:

sorted(fruit_tuples, key=itemgetter(1, 0))
[('Orange', 5, datetime.date(2020, 2, 28)),
 ('Kiwi', 8, datetime.date(2020, 6, 3)),
 ('Strawberry', 8, datetime.date(2020, 2, 1)),
 ('Banana', 10, datetime.date(2020, 2, 17)),
 ('Starfruit', 12, datetime.date(2020, 3, 4))]

The equivalent lambda for

itemgetter(1, 0)

is

lambda fruit: (fruit[1], fruit[0])

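methodcaller is different from attrgetter and itemgetter: instead of looking up an attribute or an item, it calls a named method on each object and uses the result as the key. For example, I have a list of words: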

words = ["king", "ant", "Kylie", "Aaron", "advark"]

I want to sort them in dictionary order, but the default sort puts upper case letters before lower case letters:

sorted(words)
['Aaron', 'Kylie', 'advark', 'ant', 'king']

I want to use str.casefold to do a case-insensitive comparison:

"AarON".casefold(), "aaROn".casefold()
('aaron', 'aaron')

I use methodcaller to compare strings using str.casefold:

from operator import methodcaller

sorted(words, key=methodcaller("casefold"))
['Aaron', 'advark', 'ant', 'king', 'Kylie']

The equivalent lambda for

methodcaller("casefold")

is

lambda word: word.casefold()
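When every item is known to be a str, a third spelling is to pass the method itself, str.casefold, as the key. The sketch below checks that all three spellings agree on the word list from above. Unlike methodcaller("casefold"), str.casefold only accepts str objects.

```python
from operator import methodcaller

words = ["king", "ant", "Kylie", "Aaron", "advark"]

by_methodcaller = sorted(words, key=methodcaller("casefold"))
by_lambda = sorted(words, key=lambda word: word.casefold())
by_str_method = sorted(words, key=str.casefold)  # works because every item is a str

print(by_methodcaller == by_lambda == by_str_method)  # True
print(by_str_method)  # ['Aaron', 'advark', 'ant', 'king', 'Kylie']
```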

methodcaller also supports methods which take arguments. For example, I have a list of DNA sequences:

dna_sequences = [
    "AACCTGCGGAAGGATCATTACC",
    "GAGTGCGGGTCCTTTGGGCCCA",
    "ACCTCCCATCCGTGTCTATTGT",
    "TGTTGCTTCGGCGGGCCCGCCG",
]

I can count how often "A" (Adenine) appears in a sequence using str.count:

"AACCTGCGGAAGGATCATTACC".count("A")
7

I want to order the DNA sequences by how much adenine they contain, so I use methodcaller:

sorted(dna_sequences, key=methodcaller("count", "A"))
['TGTTGCTTCGGCGGGCCCGCCG',
 'GAGTGCGGGTCCTTTGGGCCCA',
 'ACCTCCCATCCGTGTCTATTGT',
 'AACCTGCGGAAGGATCATTACC']

The equivalent lambda for

methodcaller("count", "A")

is

lambda dna_sequence: dna_sequence.count("A")

In conclusion...

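Key functions are useful when you want to compare data on a specific trait, like sorting fruits by their price. Python's standard library supports key functions for sorting (sorted and list.sort), finding the smallest or largest items (min, max, heapq.nsmallest, and heapq.nlargest), and merging sorted inputs (heapq.merge).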

The operator module supports three easy ways to create key functions: attrgetter, itemgetter, and methodcaller.

My challenge to you is:

Given a list of fruit packages selling for different prices:

from collections import namedtuple

Fruit = namedtuple("Fruit", ["name", "quantity", "price"])
fruits = [
    Fruit(name="Banana", quantity=3, price=10),
    Fruit(name="Strawberry", quantity=12, price=8),
    Fruit(name="Kiwi", quantity=3, price=8),
    Fruit(name="Starfruit", quantity=1, price=12),
    Fruit(name="Orange", quantity=2, price=5),
]

Sort the fruit list by unit price (price divided by quantity).

If you enjoyed reading this week's post, share it with your friends and stay tuned for next week's post. See you then!