
Context managers in Python

By John Lekberg on October 11, 2020.


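This week's post is about context managers in Python. You will learn what context managers are and why they matter, how Python's standard library uses them, how to create your own, and how to handle multiple context managers with contextlib.ExitStack.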

What are context managers? Why should I care?

Context managers (introduced in PEP 343) are Python objects that handle setup and teardown for resources like open files and database connections.

You should care about context managers because doing this setup and teardown by hand means writing extra code every time you work with the resource, and that code is easy to forget or get wrong.

For example, I have a file of numbers:

data.txt

61
25
51
77
66
11
53
51
96
76

I want to use Python to find the sum of these numbers, so I write this code:

print("opening file")
file = open("data.txt")
total = 0
for line in file:
    total += int(line)
print("closing file")
file.close()
print("total is", total)
opening file
closing file
total is 567

This seems to work fine, but it fails to close the file if an exception is raised. For example, suppose data.txt contains a line that isn't a number:

data.txt

61
25
51
77
not a number
11
53
51
96
76

print("opening file")
file = open("data.txt")
total = 0
for line in file:
    total += int(line)
print("closing file")
file.close()
print("total is", total)
opening file
ValueError: invalid literal for int() with base 10: 'not a number\n'

(Notice that "closing file" is not printed.)

I can fix this by using a try-statement with a "finally" clause:

print("opening file")
file = open("data.txt")
try:
    total = 0
    for line in file:
        total += int(line)
finally:
    print("closing file")
    file.close()
print("total is", total)
opening file
closing file
ValueError: invalid literal for int() with base 10: 'not a number\n'

(Notice that "closing file" is printed, even though the exception is raised.)

Because this "try-finally" pattern is so common, Python added context managers to simplify this sort of code.

For example, since files are context managers, I can take this code:

file = open("data.txt")
try:
    total = 0
    for line in file:
        total += int(line)
finally:
    file.close()

And rewrite it to use a with-statement that activates the file's context manager:

with open("data.txt") as file:
    total = 0
    for line in file:
        total += int(line)

Context managers are just objects with __enter__ and __exit__ methods. Any object with these methods is considered to be a context manager. __enter__ handles the setup. __exit__ handles the teardown.
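To make the setup/teardown split concrete, here is a simplified sketch of calling __enter__ and __exit__ by hand, roughly what the with-statement version of the summing code does for you. It is not the exact expansion from PEP 343 (which, for example, looks the methods up on the type), and it assumes io.StringIO as a stand-in for data.txt so it runs without a file:

```python
import io

# io.StringIO stands in for open("data.txt") so this runs without a file.
manager = io.StringIO("61\n25\n51\n")

# Setup: __enter__ runs first. For file objects it returns the file itself.
file = manager.__enter__()
try:
    total = 0
    for line in file:
        total += int(line)
except BaseException as exc:
    # Teardown after an error: __exit__ gets the exception details.
    # If it returns a falsy value, the exception keeps propagating.
    if not manager.__exit__(type(exc), exc, exc.__traceback__):
        raise
else:
    # Teardown after a normal finish: all three arguments are None.
    manager.__exit__(None, None, None)

print(total)           # 137
print(manager.closed)  # True
```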

Context managers are activated by the with-statement (introduced in PEP 343). __enter__ is called before the body of the statement runs, and __exit__ is called after the body finishes. Any other way of leaving the body (such as return, raise, break, or continue) also causes __exit__ to be called.
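A small sketch to check the control-flow claim, using a hypothetical Announce class that only logs when its methods run. Leaving the body with continue, break, or return still triggers __exit__:

```python
class Announce:
    """Hypothetical context manager that logs when it is entered and exited."""

    def __init__(self, name, log):
        self.name = name
        self.log = log

    def __enter__(self):
        self.log.append(f"enter {self.name}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.log.append(f"exit {self.name}")
        # Returning None (a falsy value) lets any exception propagate.


log = []

# Leaving the block with continue or break still calls __exit__.
for i in range(3):
    with Announce(f"loop {i}", log):
        if i == 0:
            continue
        if i == 1:
            break


def early_return():
    # Leaving the block with return still calls __exit__.
    with Announce("return", log):
        return "returned"


result = early_return()
print(result)  # returned
print(log)
# ['enter loop 0', 'exit loop 0', 'enter loop 1', 'exit loop 1',
#  'enter return', 'exit return']
```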

For the exact semantics of the with-statement, see PEP 343 and the with-statement section of the Python Language Reference.

How Python uses context managers in the standard library

Python's standard library uses context managers in many different places, but there are 5 main patterns to keep in mind:

  1. Close and open files.
  2. Commit or rollback database transactions.
  3. Acquire and release concurrency locks.
  4. Start and shut down concurrency/process managers.
  5. Handle specific scenarios with contextlib.

Close and open files. Objects that inherit from io.IOBase are context managers that call .close() on exit. This includes the file objects returned by open().

Here's an example:

with open("data.txt") as file:
    print(file.read())
61
25
51
77
not a number
11
53
51
96
76
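A quick way to see the .close() call happen is to check the closed attribute before and after the block. This sketch uses io.StringIO, an in-memory text file that also inherits from io.IOBase, so it runs without data.txt:

```python
import io

# io.StringIO is an in-memory text file; it inherits from io.IOBase.
buffer = io.StringIO("61\n25\n")
print(isinstance(buffer, io.IOBase))  # True

with buffer as f:
    print(f.closed)   # False: still open inside the block
    text = f.read()

print(buffer.closed)  # True: leaving the block called .close()
```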

Commit or rollback database transactions. sqlite3.Connection objects are context managers that will either commit or rollback a transaction, depending on how the context manager is exited. If the context manager exits normally, then the transaction is committed. If an exception is raised, then the transaction is rolled back.

Here's an example: I have a database of accounts with value transfers:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE Account(name TEXT, change REAL);
INSERT INTO Account
VALUES ('Bob', 300), ('Henry', 200);
""")


def report():
    print(conn.execute("""
      SELECT name, sum(change)
        FROM Account
    GROUP BY name
    ORDER BY name ASC
    """).fetchall())


report()
[('Bob', 300.0), ('Henry', 200.0)]

Bob has $300, and Henry has $200. Then Henry transfers $100 to Bob:

with conn:
    conn.execute(
        "INSERT INTO Account VALUES ('Bob', 100)"
    )
    conn.execute(
        "INSERT INTO Account VALUES ('Henry', -100)"
    )
report()
[('Bob', 400.0), ('Henry', 100.0)]

(Notice that this transaction was committed because there were no issues.)

Then Henry transfers $100 to Bob again, but an error occurs in the middle of the transaction:

with conn:
    conn.execute(
        "INSERT INTO Account VALUES ('Bob', 100)"
    )
    raise Exception()
    conn.execute(
        "INSERT INTO Account VALUES ('Henry', -100)"
    )
Exception:
report()
[('Bob', 400.0), ('Henry', 100.0)]

(Notice that this transaction was rolled back because an exception was raised. Bob did not gain $100.)
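One thing the sqlite3 context manager does not do is close the connection: it only commits or rolls back. A minimal sketch, assuming you also want the connection closed, wraps it in contextlib.closing() and nests "with conn:" inside it for the transaction:

```python
import sqlite3
from contextlib import closing

# closing() calls conn.close() on exit. "with conn:" only manages a transaction.
with closing(sqlite3.connect(":memory:")) as conn:
    conn.execute("CREATE TABLE Account(name TEXT, change REAL)")
    with conn:  # committed, because the block exits normally
        conn.execute("INSERT INTO Account VALUES ('Bob', 300)")
    # The connection is still open after the transaction block.
    rows = conn.execute("SELECT name, change FROM Account").fetchall()
    print(rows)  # [('Bob', 300.0)]

# Outside closing(), the connection has been closed.
try:
    conn.execute("SELECT 1")
    connection_closed = False
except sqlite3.ProgrammingError:
    connection_closed = True
print(connection_closed)  # True
```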

Acquire and release concurrency locks. Concurrency lock objects are context managers that call .acquire() on enter and call .release() on exit. This includes threading.Lock, threading.RLock, and threading.Semaphore.

Here's an example: I have two tasks that I want to run concurrently that both print output:

import threading
import time


def taskA():
    for i in range(5):
        print("A", i)
        time.sleep(1)


def taskB():
    for i in range(5):
        print("B", i)
        time.sleep(1)


threadA = threading.Thread(target=taskA)
threadB = threading.Thread(target=taskB)
threadA.start()
threadB.start()
AB 0
 0
AB 1
 1
BA 2
 2
BA 3
 3
AB 4
 4

The problem with this is that the outputs from the two tasks overlap. I can fix this by using a lock as a context manager:

print_lock = threading.Lock()


def taskA():
    for i in range(5):
        with print_lock:
            print("A", i)
        time.sleep(1)


def taskB():
    for i in range(5):
        with print_lock:
            print("B", i)
        time.sleep(1)


threadA = threading.Thread(target=taskA)
threadB = threading.Thread(target=taskB)
threadA.start()
threadB.start()
A 0
B 0
A 1
B 1
A 2
B 2
A 3
B 3
A 4
B 4
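For comparison, "with print_lock:" is roughly shorthand for calling .acquire() and then .release() in a try-finally. A simplified sketch of both forms, using Lock.locked() to check the lock's state:

```python
import threading

lock = threading.Lock()

# Written out by hand: acquire, then always release.
lock.acquire()
try:
    print("holding the lock:", lock.locked())  # True
finally:
    lock.release()

# The same thing using the lock as a context manager.
with lock:
    print("holding the lock:", lock.locked())  # True
print("after the block:", lock.locked())       # False
```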

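Start and shut down concurrency/process managers. Concurrency/process managers are context managers that are shut down on exit. This includes concurrent.futures.ThreadPoolExecutor and concurrent.futures.ProcessPoolExecutor, whose __exit__ calls .shutdown(wait=True) and so waits for pending tasks to finish.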

Here's an example: I use a ThreadPoolExecutor object to run two tasks concurrently and then print a message after the executor is shut down:

from concurrent.futures import ThreadPoolExecutor
import threading
import time

print_lock = threading.Lock()


def taskA():
    for i in range(5):
        with print_lock:
            print("A", i)
        time.sleep(1)


def taskB():
    for i in range(5):
        with print_lock:
            print("B", i)
        time.sleep(1)


print("Get ready!")
with ThreadPoolExecutor() as executor:
    executor.submit(taskA)
    executor.submit(taskB)
print("All finished!")
Get ready!
A 0
B 0
A 1
B 1
A 2
B 2
B 3
A 3
A 4
B 4
All finished!
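Because the executor's __exit__ waits for submitted tasks, every future is finished once the with-block ends. A small sketch with a hypothetical square task, also showing that a shut-down executor refuses new work:

```python
from concurrent.futures import ThreadPoolExecutor


def square(n):
    return n * n


with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [executor.submit(square, n) for n in range(5)]

# Leaving the block called executor.shutdown(wait=True),
# so every submitted task has already finished here.
print(all(future.done() for future in futures))  # True
print([future.result() for future in futures])   # [0, 1, 4, 9, 16]

# A shut-down executor refuses new work.
try:
    executor.submit(square, 5)
except RuntimeError as exc:
    print("RuntimeError:", exc)
```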

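Handle specific scenarios with contextlib. The contextlib module provides several context managers for specific scenarios, such as contextlib.suppress (ignore specific exceptions), contextlib.redirect_stdout (temporarily redirect printed output), and contextlib.closing (call .close() on exit).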
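For instance, a short sketch of two of these, contextlib.suppress and contextlib.redirect_stdout (the settings dictionary is just illustrative data):

```python
import io
from contextlib import redirect_stdout, suppress

settings = {"theme": "dark"}

# suppress(...) swallows the listed exception types raised inside the block.
with suppress(KeyError):
    del settings["font"]  # no "font" key: KeyError is raised, then suppressed

# redirect_stdout(...) temporarily sends print() output to another file object.
buffer = io.StringIO()
with redirect_stdout(buffer):
    print("captured!")

print(repr(buffer.getvalue()))  # 'captured!\n'
```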

How to create context managers

You can create a context manager by defining a class with __enter__ and __exit__ methods. E.g.

class Talky:
    def __enter__(self):
        print("Wow! Entering a context.")

    def __exit__(self, exc_type, exc_value, traceback):
        if exc_type is not None:
            print("Leaving with an error?!")
        else:
            print("Leaving normally. Boring...")


with Talky():
    pass
Wow! Entering a context.
Leaving normally. Boring...
with Talky():
    raise Exception()
Wow! Entering a context.
Leaving with an error?!
Exception:

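What are the three parameters exc_type, exc_value, and traceback? If the body of the with-statement raises an exception, they are the exception's type, the exception object itself, and its traceback. If the body finishes normally, all three are None.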

For more details, see "3.3.9. With Statement Context Managers".
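To see what __exit__ actually receives, here is a sketch with a hypothetical ShowExit class that stores its three arguments. It also returns a configurable value, since a truthy return value from __exit__ tells Python to suppress the exception (Talky returns None, which is why its exception propagated):

```python
class ShowExit:
    """Hypothetical context manager that records what __exit__ receives."""

    def __init__(self, suppress=False):
        self.suppress = suppress
        self.received = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.received = (exc_type, exc_value, traceback)
        # A truthy return value tells Python to swallow the exception.
        return self.suppress


with ShowExit() as normal:
    pass
print(normal.received)  # (None, None, None)

with ShowExit(suppress=True) as failing:
    raise ValueError("bad number")
exc_type, exc_value, tb = failing.received
print(exc_type.__name__, exc_value)  # ValueError bad number
print(tb is not None)                # True
```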

Even if you only need to use one of the methods (__enter__ or __exit__), you need to implement both of them. E.g. just implementing __enter__ causes an error:

class TalkEnter:
    def __enter__(self):
        print("Entering a context.")


with TalkEnter():
    pass
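AttributeError: __exit__

(Python 3.11 and later raise a TypeError here instead, saying that the object does not support the context manager protocol.)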

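The contextlib module provides shortcuts for this, such as the contextlib.contextmanager decorator (which builds a context manager from a generator function) and the contextlib.AbstractContextManager base class (which supplies a default __enter__ that returns self).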
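A minimal sketch of two such shortcuts from the contextlib module: the @contextmanager decorator, where the code before yield acts as the setup and the code after it acts as the teardown, and the AbstractContextManager base class, which supplies a default __enter__ that returns self. The talky and TalkExit names are illustrative re-creations of the classes above, not standard library names:

```python
from contextlib import AbstractContextManager, contextmanager


@contextmanager
def talky():
    # Code before "yield" runs on enter.
    print("Wow! Entering a context.")
    try:
        yield
    except Exception:
        # An exception from the with-block is thrown into the generator here.
        print("Leaving with an error?!")
        raise  # re-raise so the exception still propagates, like Talky
    else:
        print("Leaving normally. Boring...")


class TalkExit(AbstractContextManager):
    # AbstractContextManager provides __enter__ (returning self),
    # so only __exit__ has to be written here.
    def __exit__(self, exc_type, exc_value, traceback):
        print("Leaving a context.")


with talky():
    pass

try:
    with talky():
        raise ValueError("oops")
except ValueError:
    print("The ValueError still propagated.")

with TalkExit() as t:
    print("Inside, t is a", type(t).__name__)
```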

How to handle multiple context managers

A simple way to handle multiple context managers is to nest with-statements. E.g.

with open("in.txt") as fin:
    with open("out.txt", "w") as fout:
        for line in fin:
            fout.write(line.upper())

You can also have multiple items in one with-statement, separated by commas. E.g.

with open("in.txt") as fin, open("out.txt", "w") as fout:
    for line in fin:
        fout.write(line.upper())

(This is equivalent to nesting the context managers, from left to right.)

But if you have a large (or variable) number of context managers, then you should use contextlib.ExitStack. E.g.

from contextlib import ExitStack

with ExitStack() as stack:
    fin = stack.enter_context(open("in.txt"))
    fout = stack.enter_context(open("out.txt", "w"))
    for line in fin:
        fout.write(line.upper())

Here's a more complex example: I have 20 data files, in-1.txt, in-2.txt, ..., in-20.txt. Each file is sorted, and I want to create out.txt, which contains the sorted lines of all the input files. Here's how I would manage these files using ExitStack:

from contextlib import ExitStack
import heapq

filenames = [f"in-{n}.txt" for n in range(1, 21)]
with ExitStack() as stack:
    fins = [
        stack.enter_context(open(filename))
        for filename in filenames
    ]
    fout = stack.enter_context(open("out.txt", "w"))
    for line in heapq.merge(*fins):
        fout.write(line)
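ExitStack also cleans up properly if something fails partway through, for example if one of the input files cannot be opened. A sketch using a hypothetical Resource class in place of real files, where entering "c" fails:

```python
from contextlib import ExitStack

log = []


class Resource:
    """Hypothetical resource that logs its setup and teardown."""

    def __init__(self, name, fail=False):
        self.name = name
        self.fail = fail

    def __enter__(self):
        if self.fail:
            raise OSError(f"could not open {self.name}")
        log.append(f"open {self.name}")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        log.append(f"close {self.name}")


try:
    with ExitStack() as stack:
        for name in ["a", "b", "c"]:
            stack.enter_context(Resource(name, fail=(name == "c")))
except OSError as exc:
    print("error:", exc)  # error: could not open c

# Only the resources that were entered get exited, in reverse order.
print(log)  # ['open a', 'open b', 'close b', 'close a']
```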

How to register callback functions with ExitStack

contextlib.ExitStack.callback can also be used to register callback functions that are executed when the context manager is exited. E.g.

from contextlib import ExitStack

with ExitStack() as stack:
    print(1)
    stack.callback(print, 2)
    print(3)
    stack.callback(print, 4)
1
3
4
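2

(Notice that the callbacks run in the reverse order of registration: print(4) runs before print(2).)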
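Registered callbacks also run when the block exits with an exception, still in reverse order. A short sketch (the events list is illustrative):

```python
from contextlib import ExitStack

events = []
try:
    with ExitStack() as stack:
        stack.callback(events.append, "cleanup 1")
        stack.callback(events.append, "cleanup 2")
        raise RuntimeError("something went wrong")
except RuntimeError:
    events.append("handled")

print(events)  # ['cleanup 2', 'cleanup 1', 'handled']
```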

I find this useful for joining threads in multithreaded code. For example, here's a task that I want to run in several threads:

import threading
import time

print_lock = threading.Lock()


def task(n):
    for i in range(n, n + 3):
        with print_lock:
            print(f"{n}-{i}")
        time.sleep(1)

I could manage the threads by creating a list of threads, then looping over that list to start the threads, and then looping over it again to join them:

threads = [
    threading.Thread(target=task, args=[i])
    for i in range(5)
]

for thread in threads:
    thread.start()

for thread in threads:
    thread.join()
0-0
1-1
2-2
3-3
4-4
1-2
3-4
2-3
0-1
4-5
1-3
2-4
0-2
3-5
4-6

But I can avoid explicitly keeping track of the threads for starting and joining by registering each thread's join method as a callback on an ExitStack:

with ExitStack() as stack:
    for i in range(5):
        thread = threading.Thread(target=task, args=[i])
        thread.start()
        stack.callback(thread.join)
0-0
1-1
2-2
3-3
4-4
1-2
4-5
0-1
2-3
3-4
1-3
4-6
0-2
2-4
3-5

In conclusion...

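In this week's post you learned what context managers are and why they're useful, the main ways Python's standard library uses them, how to create your own context managers, and how to handle multiple context managers and cleanup callbacks with contextlib.ExitStack.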

My challenge to you:

Create a context manager called tag that prints opening and closing XML tags. E.g.

with tag("body"):
    with tag("h1"):
        print("My Document")
    with tag("p"):
        print("Lorem ipsum")
    with tag("ul"):
        for i in range(4):
            with tag("li"):
                print(i)
<body>
<h1>
My Document
</h1>
<p>
Lorem ipsum
</p>
<ul>
<li>
0
</li>
<li>
1
</li>
<li>
2
</li>
<li>
3
</li>
</ul>
</body>

If you enjoyed this week's post, share it with your friends and stay tuned for next week's post. See you then!


(If you spot any errors or typos on this post, contact me via my contact page.)