Chaining ContextManager values when mocking tests

Written by
Holden Rehg
Posted on
January 29, 2020

Recently, I’ve used unittest.mock.patch to mock connections to external services while developing a Django app. While writing the tests, I found myself repeatedly writing code that looks like:

def test_valid_connection_has_connected_status():
    with mock.patch("myapp.connection", MockConnection):
        with mock.patch("myapp.authenticator", MockAuthenticator):
            assertEqual(connect().status, "connected")

It’s not a huge deal at first, but as you add more use cases the nesting really bloats up. I wanted utilities with a clear interface for the team to prevent mistakes, make the tests more readable, and save a few characters. How to dynamically chain context managers together, whether through variables or functions, wasn’t perfectly clear (at least it wasn’t to me).

If you aren’t familiar with context managers, go read through the data model docs and the language reference docs on the with statement.
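As a quick refresher, a context manager is any object with __enter__ and __exit__ methods. Here is a minimal sketch of that protocol; the Recorder class is made up for illustration and is not part of the app under test.

```python
class Recorder:
    """Hypothetical context manager that records when it is entered and exited."""

    def __init__(self):
        self.events = []

    def __enter__(self):
        # Runs when the with block starts; the return value is bound by "as".
        self.events.append("enter")
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Runs when the with block ends, even if the body raised.
        self.events.append("exit")
        return False  # do not swallow exceptions


recorder = Recorder()
with recorder as r:
    r.events.append("body")
```

After the block finishes, recorder.events holds the enter, body, and exit markers in that order.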


One potential solution we have is to pull out some common variables.

Looking back at our example, we may want to split out the authenticators so that we can easily handle both authenticated and unauthenticated scenarios.

I’m going to do that by extending the MockAuthenticator to take in a parameter called authenticated.

class MockAuthenticator:
    def __init__(self, authenticated):
        self.authenticated = authenticated

    # We can trick the application into thinking
    # that it's initializing a new object, while
    # it's really just calling this function and
    # getting back the same object.
    def __call__(self, *args, **kwargs):
        return self
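To make the __call__ trick concrete, here is a small sketch of what the application sees when it "instantiates" the patched authenticator. The arguments passed in the call are invented for illustration; the real authenticator's signature isn't shown here.

```python
class MockAuthenticator:
    def __init__(self, authenticated):
        self.authenticated = authenticated

    def __call__(self, *args, **kwargs):
        # Calling the instance ignores the arguments and hands back the same object.
        return self


mock_auth = MockAuthenticator(authenticated=False)

# Application code believes it is constructing a fresh authenticator here.
# The arguments are hypothetical placeholders.
instance = mock_auth("some-user", token="some-token")

same_object = instance is mock_auth
```

Because the app gets the pre-configured instance back, the authenticated flag we chose in the test is the one the app ends up reading.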

Then we can pull out some variables for each scenario to re-use throughout the entire test suite. This actually isn’t a bad option.

test_connection = mock.patch("myapp.connection", MockConnection)
authed = mock.patch("myapp.authenticator", MockAuthenticator(authenticated=True))
unauthed = mock.patch("myapp.authenticator", MockAuthenticator(authenticated=False))

def test_valid_connection_has_connected_status():
    with test_connection, authed:
        assertEqual(connect().status, "connected")

def test_invalid_connection_has_disconnected_status():
    with test_connection, unauthed:
        assertEqual(connect().status, "disconnected")

But I didn’t really like the fact that I had to repeat test_connection in every test, no matter what type of connection I was using. I attempted to abstract those up into a single context manager, but ran into issues.

The with statement can handle multiple “arguments” as you see in the example above, but you can’t pass it an iterable or unpack an iterable the way you can with function calls. That would have made things simpler, because we would only be dealing with a couple of mock.patch variables instantiated within a reusable authed function.

def authed():
    return (
        mock.patch("myapp.connection", MockConnection),
        mock.patch("myapp.authenticator", MockAuthenticator(authenticated=True)),
    )

def test_valid_connection_has_connected_status():
    # Does Not Work :(
    with authed():
        assertEqual(connect().status, "connected")
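To see why a tuple of context managers fails, here is a small standalone sketch. It uses contextlib.nullcontext as a stand-in for the mock.patch objects, and the pair function is hypothetical. As far as I can tell, the exact exception type depends on the Python version, so the sketch catches both candidates.

```python
from contextlib import nullcontext


def pair():
    # Two valid context managers, returned together as a tuple.
    return (nullcontext("a"), nullcontext("b"))


try:
    with pair():
        pass
except (AttributeError, TypeError) as error:
    # The tuple itself has no __enter__/__exit__, so the with statement rejects it.
    failure = type(error).__name__

# Unpacking by hand works, but puts the repetition back into every test.
first, second = pair()
with first as a, second as b:
    values = (a, b)
```

The tuple's members are perfectly good context managers; the with statement just never looks inside the tuple.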

Yield functions

The best option in my mind is to pull out two functions that make it clear what is going on in each test, while mocking multiple objects behind the scenes.

We can do this by defining contextmanager functions, applying each mock as needed and then yielding to the calling function. The contextmanager decorator (from contextlib) is a way to define a context manager without needing explicit __enter__ and __exit__ methods like you would in a class definition.
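A minimal sketch of what those two functions could look like is below. To keep it runnable on its own, it registers a fake myapp module and an empty MockConnection class; both are stand-ins I made up for this example, not the real app code.

```python
import sys
import types
from contextlib import contextmanager
from unittest import mock

# Hypothetical stand-in for the real myapp module so the sketch runs on its own.
myapp = types.ModuleType("myapp")
myapp.connection = None
myapp.authenticator = None
sys.modules["myapp"] = myapp


class MockConnection:
    """Placeholder connection class (assumed for this sketch)."""


class MockAuthenticator:
    def __init__(self, authenticated):
        self.authenticated = authenticated

    def __call__(self, *args, **kwargs):
        return self


@contextmanager
def authed():
    # Both patches stay active for as long as the caller is inside its with block.
    with mock.patch("myapp.connection", MockConnection):
        with mock.patch("myapp.authenticator", MockAuthenticator(authenticated=True)):
            yield


@contextmanager
def unauthed():
    with mock.patch("myapp.connection", MockConnection):
        with mock.patch("myapp.authenticator", MockAuthenticator(authenticated=False)):
            yield
```

The nesting still exists, but it lives in one place. When the caller's with block ends, execution resumes after the yield and both patches are undone in reverse order.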

Now we redefine our tests. Check them out. I personally think they are much easier to understand at a glance, with a simple interface for auth mocks. Plus it will save us a few characters as an added bonus.

def test_valid_connection_has_connected_status():
    with authed():
        assertEqual(connect().status, "connected")

def test_invalid_connection_has_disconnected_status():
    with unauthed():
        assertEqual(connect().status, "disconnected")

A note about exit stacks

While it doesn’t make sense for the exact scenario we’re dealing with above, we could have also potentially used an ExitStack.

I had no clue what an ExitStack was until I started hunting around for a way to chain together context managers dynamically. In the Python docs, it’s defined as “[a way] to make it easy to programmatically combine other context managers and cleanup functions, especially those that are optional or otherwise driven by input data.”

Sounds great. The general idea is that I can enter the ExitStack and have more control over the context managers that I push onto the stack. When the ExitStack closes, it pops every registered context manager off the stack and closes them in reverse order of their registration (first in, last out).

Check out what that looks like:

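from contextlib import ExitStack

with ExitStack() as stack:
    # Open each file and register it with the stack.
    for name in ["file_1.txt", "file_2.txt", "file_3.txt", "file_4.txt"]:
        stack.enter_context(open(name))

# After leaving scope, the stack will close all registered context managers:
#     close file_4.txt
#     close file_3.txt
#     close file_2.txt
#     close file_1.txt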

This is essentially the same as running nested with statements, except that the context managers can be registered programmatically, which is what we would need in order to abstract out authed and unauthed context managers.

with open("file_1.txt"):
    with open("file_2.txt"):
        with open("file_3.txt"):
            with open("file_4.txt"):
                pass

The primary benefit is that while you cannot, for example, iterate over a list of files and pass a list of open results into a with statement, we can do that with the ExitStack.
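Here is a small runnable sketch of that idea. Instead of touching real files, it uses a made-up resource context manager that records when it is opened and closed, so the ordering is easy to inspect.

```python
from contextlib import ExitStack, contextmanager

events = []


@contextmanager
def resource(name):
    # Hypothetical stand-in for open(): records when it is entered and exited.
    events.append(f"open {name}")
    try:
        yield name
    finally:
        events.append(f"close {name}")


names = ["file_1.txt", "file_2.txt", "file_3.txt"]

with ExitStack() as stack:
    # The list can be any length and can be built at runtime.
    handles = [stack.enter_context(resource(name)) for name in names]

print(events)
```

The printed list shows the three opens in order, followed by the closes in reverse order.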

We can take a look at another common example when dealing with mocks. Assume that you need to test a connection to an external API or database and you want to create a single test that interacts with a number of outside mocks.

You might start with something like this for each service adapter:

def test_mysql_database_adapters():
    with mock.patch("db.connection", MockMySQLConnection):
        ...

def test_psql_database_adapters():
    with mock.patch("db.connection", MockPostgresConnection):
        ...

def test_redis_database_adapters():
    with mock.patch("db.connection", MockRedisConnection):
        ...

def test_sqlite_database_adapters():
    with mock.patch("db.connection", MockSQLiteConnection):
        ...

Instead, we can create a context manager that iterates over all possible adapters, creates a new context for them via a db.connect function, and then yields those back to the test. Once that test has finished running, the ExitStack handles all of the cleanup for those connections.

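from contextlib import contextmanager, ExitStack

@contextmanager
def db_env(adapters):
    with ExitStack() as stack:
        # Enter one connection context per adapter and hand the list to the test.
        yield [stack.enter_context(db.connect(adapter)) for adapter in adapters]

def test_database_adapters():
    with db_env(SUPPORTED_ADAPTERS) as connections:
        for connection in connections:
            ...  # run the shared adapter assertions against each connection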
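For a version of this pattern you can run without a database, here is a sketch with a fake connect function standing in for db.connect. The adapter names and the string "connections" are invented for illustration.

```python
from contextlib import ExitStack, contextmanager

closed = []


@contextmanager
def connect(adapter):
    # Hypothetical stand-in for db.connect: yields a fake connection
    # and records the adapter name when the connection is cleaned up.
    try:
        yield f"{adapter}-connection"
    finally:
        closed.append(adapter)


@contextmanager
def db_env(adapters):
    with ExitStack() as stack:
        yield [stack.enter_context(connect(adapter)) for adapter in adapters]


SUPPORTED_ADAPTERS = ["mysql", "psql", "redis", "sqlite"]

with db_env(SUPPORTED_ADAPTERS) as connections:
    seen = list(connections)
```

Every connection stays open while the with block runs, and the stack then tears them down starting with the most recently opened one.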

Context managers can be an extremely useful abstraction in your Python code if you don’t go overboard with them. Give them a shot in your code or tests to see where they might make sense.

Thanks For Reading

I appreciate you taking the time to read any of my articles. I hope it has helped you out in some way. If you're looking for more ramblings, take a look at the entire catalog of articles I've written. Give me a follow on Twitter or Github to see what else I've got going on. Feel free to reach out if you want to talk!
