In this section we will see common problems when mocking our code.
`patch` not patching the correct object's attribute

It is very common to spend a lot of time figuring out why `patch` does not work.
%%writefile database.py
class DBConnection:
    """Stand-in for a real database connection that announces itself on stdout."""

    def __init__(self, dsn):
        print('Connected to real database')
        # The connection string is stored verbatim.
        self.dsn = dsn

    def cursor(self):
        """Hand back a brand-new ``Cursor`` on every call."""
        new_cursor = Cursor()
        return new_cursor

    def commit(self):
        """Pretend to persist pending changes by printing a message."""
        print('Saved changes')
class Cursor:
    """Fake cursor that echoes each query instead of running it."""

    def execute(self, query):
        message = "Executed query={}".format(query)
        print(message)
def connect(dsn):
    """Open a (simulated) connection described by *dsn* and return it."""
    connection = DBConnection(dsn)
    return connection
Overwriting database.py
%%writefile script.py
from database import connect
def clean_db():
    """Truncate the ``clickouts`` and ``images`` tables, then commit.

    ``connect`` is the name this module imported from ``database``, so it is
    looked up in *this* module's namespace at call time.
    """
    conn = connect(dsn="user='123', password='xxx', host='hotels.prod.aws.com'")
    cursor = conn.cursor()
    for query in ('TRUNCATE clickouts', 'TRUNCATE images'):
        cursor.execute(query)
    conn.commit()
Overwriting script.py
from unittest.mock import Mock, call, patch
from script import clean_db
def test_clean_db():
    """Intentionally broken test: it patches a name ``clean_db`` never uses.

    ``script`` did ``from database import connect``, so ``clean_db`` calls
    ``script.connect``. Patching ``database.connect`` leaves that reference
    untouched, and the real connection code runs (see the output below).
    """
    with patch('database.connect') as db_mock:
        clean_db()
        # clean_db never called db_mock, so its cursor mock has recorded no
        # calls and this assertion fails with AssertionError.
        assert db_mock().cursor().method_calls == [
            call.execute('TRUNCATE clickouts'),
            call.execute('TRUNCATE images')
        ]
test_clean_db()
Connected to real database Executed query=TRUNCATE clickouts Executed query=TRUNCATE images Saved changes
--------------------------------------------------------------------------- AssertionError Traceback (most recent call last) <ipython-input-3-9189f24e6aa0> in <module>() 10 ] 11 ---> 12 test_clean_db() <ipython-input-3-9189f24e6aa0> in test_clean_db() 7 assert db_mock().cursor().method_calls == [ 8 call.execute('TRUNCATE clickouts'), ----> 9 call.execute('TRUNCATE images') 10 ] 11 AssertionError:
... But we patched `database.connect`! Why did it connect to the real database and execute queries?
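Explanation: `patch('database.connect')` does not patch `script.connect`. `script.connect` is a separate copy of the reference to the original `database.connect` function, made when `script.py` executed `from database import connect`.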
Before patching:
database.connect = <ORIGINAL database.connect function>
database.DBConnection = <ORIGINAL database.DBConnection class>
script.connect = <ORIGINAL database.connect function>
script.clean_db = <ORIGINAL script.clean_db function>
After patching:
database.connect = <db_mock>
database.DBConnection = <ORIGINAL database.DBConnection class>
script.connect = <ORIGINAL database.connect function>
script.clean_db = <ORIGINAL script.clean_db function>
What `patch('database.connect')` does is patch the attribute `connect` of the `database` module, roughly like this:
from contextlib import contextmanager


@contextmanager
def patch_database_connect():
    """Simplified sketch of what ``patch('database.connect')`` does.

    Temporarily replaces the ``connect`` attribute of the ``database`` module
    with a ``Mock`` and yields that mock. Any other module that already holds
    its own reference to the original function (such as ``script.connect``)
    is not affected.
    """
    import database
    original_function = database.connect
    database.connect = db_mock = Mock()
    try:
        yield db_mock
    finally:
        # Restore the original even if the ``with`` body raises, as the real
        # ``patch`` does.
        database.connect = original_function
The attribute `connect` of the `script` module is a copy of the original reference to `<ORIGINAL database.connect function>`, so replacing `database.connect` does not change it.
Be careful with the references you copy into your module with `from ... import ...`.
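Possible fixes:
1. In `script.py`, replace `from database import connect` with `import database` and use `database.connect`.
2. Patch the reference that `script` actually uses: `patch('script.connect')`.
3. Patch `patch('database.DBConnection')`. `database.connect` looks `DBConnection` up in the `database` module every time it is called, so the patched class is used. Example: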
from unittest.mock import Mock, patch, call
from script import clean_db
def test_clean_db():
    """Patch the class that ``database.connect`` instantiates at call time.

    ``connect`` resolves ``DBConnection`` from the ``database`` module's
    globals on every call. So replacing ``database.DBConnection`` takes effect
    even though ``script`` holds its own reference to ``connect``.
    """
    expected_queries = [
        call.execute('TRUNCATE clickouts'),
        call.execute('TRUNCATE images')
    ]
    with patch('database.DBConnection') as db_mock:
        clean_db()
        fake_cursor = db_mock().cursor()
        assert fake_cursor.method_calls == expected_queries


test_clean_db()
`self` in `side_effect`

A common headache appears when trying to patch a method of a class with a custom function that receives `self` as a parameter, like so:
from unittest.mock import Mock, MagicMock, patch
class Table:
    """Minimal table handle whose rows come from a (simulated) database."""

    def __init__(self, name):
        self.table_name = name

    def get_rows(self):
        """Fetch all rows of the table; always ``[1, 2, 3]`` in this example."""
        print("Retrieve rows from database")
        rows = [1, 2, 3]
        return rows
def get_all_data():
    """Return the rows of the ``users`` and ``jobs`` tables, keyed by table name.

    Both ``Table`` objects are created first and queried afterwards.
    """
    tables = {name: Table(name) for name in ('users', 'jobs')}
    return {name: table.get_rows() for name, table in tables.items()}
One way to test the function `get_all_data` would be to patch the `Table.get_rows` method so that it returns pre-defined rows based on the value of `self.table_name`, like so:
# Canned rows per table name, used as the expected result of get_all_data().
row_data = dict(
    users=['user_row_1', 'user_row_2'],
    jobs=['job_row_1', 'job_row_2', 'job_row_3'],
)
# with patch.object(Table, 'get_rows', side_effect=WHAT DO WE INSERT HERE?):
The problem comes when we want to define the `side_effect`. If you try this, it won't work:
# Broken attempt: the MagicMock replacing Table.get_rows is not bound to
# instances, so users.get_rows() calls the side_effect without `self` and the
# lambda raises TypeError (see the traceback below).
with patch.object(Table, 'get_rows', side_effect=lambda self: row_data[self.table_name]):
    assert get_all_data() == row_data
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-9-e5197812737e> in <module>() 1 with patch.object(Table, 'get_rows', side_effect=lambda self: row_data[self.table_name]): ----> 2 assert get_all_data() == row_data <ipython-input-6-b2612a3d4105> in get_all_data() 11 users = Table('users') 12 jobs = Table('jobs') ---> 13 return {'users': users.get_rows(), 14 'jobs': jobs.get_rows()} /usr/local/lib/python3.6/unittest/mock.py in __call__(_mock_self, *args, **kwargs) 937 # in the signature 938 _mock_self._mock_check_sig(*args, **kwargs) --> 939 return _mock_self._mock_call(*args, **kwargs) 940 941 /usr/local/lib/python3.6/unittest/mock.py in _mock_call(_mock_self, *args, **kwargs) 1003 return result 1004 -> 1005 ret_val = effect(*args, **kwargs) 1006 1007 if (self._mock_wraps is not None and TypeError: <lambda>() missing 1 required positional argument: 'self'
The parameter `self` is not passed to our `side_effect`, but we need it.
If we check what `get_rows` is, we will see:
# Compare what `get_rows` is before and while it is patched.
print(Table.get_rows)  # plain function on the class
print(Table('users').get_rows)  # bound method on an instance
with patch.object(Table, 'get_rows', side_effect=lambda self: row_data[self.table_name]) as mock_get_rows:
    # All three print the same MagicMock object (see output below).
    print(mock_get_rows)
    print(Table.get_rows)
    print(Table('users').get_rows)
<function Table.get_rows at 0x7fe294c7f620> <bound method Table.get_rows of <__main__.Table object at 0x7fe294c45588>> <MagicMock name='get_rows' id='140611135231592'> <MagicMock name='get_rows' id='140611135231592'> <MagicMock name='get_rows' id='140611135231592'>
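An instance's `get_rows` method is bound to that instance.
The difference between a function and a bound method is that, when you call a bound method, the instance (`self`) is automatically passed as the first argument. A `MagicMock` does not get bound this way. As the output above shows, `Table.get_rows`, `Table('users').get_rows` and the mock are all the same object, so no `self` is passed.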
users = Table('users')
Table.get_rows(users) # Calling through the class: the instance must be passed explicitly as `self`.
users.get_rows() # Bound method: `self` is passed automatically.
Retrieve rows from database Retrieve rows from database
[1, 2, 3]
What the mocking library does is roughly this:
# Hand-written equivalent of patch.object(Table, 'get_rows', side_effect=...).
original_get_rows = Table.get_rows
try:
    mock_get_rows = MagicMock()
    mock_get_rows.side_effect = lambda self: row_data[self.table_name] # a MagicMock is not bound to instances
    Table.get_rows = mock_get_rows
    users = Table('users')
    print(Table.get_rows(users)) # Works: the instance is passed explicitly
    print(users.get_rows)  # the MagicMock itself, not a bound method
    print(users.get_rows()) # Fails: the side_effect receives no `self`
finally:
    # Always put the real method back.
    Table.get_rows = original_get_rows
['user_row_1', 'user_row_2'] <MagicMock id='140611135411760'>
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-15-060b8d59ac33> in <module>() 8 print(Table.get_rows(users)) # This works 9 print(users.get_rows) ---> 10 print(users.get_rows()) # This doesn't 11 finally: 12 Table.get_rows = original_get_rows /usr/local/lib/python3.6/unittest/mock.py in __call__(_mock_self, *args, **kwargs) 937 # in the signature 938 _mock_self._mock_check_sig(*args, **kwargs) --> 939 return _mock_self._mock_call(*args, **kwargs) 940 941 /usr/local/lib/python3.6/unittest/mock.py in _mock_call(_mock_self, *args, **kwargs) 1003 return result 1004 -> 1005 ret_val = effect(*args, **kwargs) 1006 1007 if (self._mock_wraps is not None and TypeError: <lambda>() missing 1 required positional argument: 'self'
Solutions
Option 1: Use `patch.object` to temporarily assign a new `get_rows` function, without a `Mock` object:
# new= installs the lambda itself (a plain function), so it is bound like a
# normal method and receives `self`. The `as` target is that same lambda.
with patch.object(Table, 'get_rows', new=lambda self: row_data[self.table_name]) as mock_get_rows:
    assert get_all_data() == row_data
    print(mock_get_rows)
    print(Table.get_rows)
    print(Table('aaa').get_rows)
<function <lambda> at 0x7fe29554ed08> <function <lambda> at 0x7fe29554ed08> <bound method <lambda> of <__main__.Table object at 0x7fe294cf72e8>>
Pros: short, and works.
Cons: you lose the call history of `get_rows` that you would get if you used a `Mock` as `new`.
Option 2: Create one mocked `Table` instance for each expected table. This way you can pre-define the `return_value` for each instance and you don't need `self`.
from unittest.mock import patch, MagicMock, Mock, call
def mocked_table_instance(table_name, rows):
    """Build a MagicMock standing in for ``Table(table_name)``.

    The fake exposes ``table_name`` and its ``get_rows()`` returns *rows*.
    The dotted keyword is applied through MagicMock's ``configure_mock``.
    """
    return MagicMock(table_name=table_name, **{'get_rows.return_value': rows})
# One pre-configured fake per table name that get_all_data() will create.
mocked_tables = {'users': mocked_table_instance('users', row_data['users']),
                 'jobs': mocked_table_instance('jobs', row_data['jobs'])}
# Replace the Table class: "constructing" a table returns the matching fake.
with patch('__main__.Table', side_effect=lambda table_name: mocked_tables[table_name]) as table_class_mock:
    assert get_all_data() == row_data
    assert table_class_mock.mock_calls == [call('users'),
                                           call('jobs')]
    assert mocked_tables['users'].method_calls == [call.get_rows()]
    assert mocked_tables['jobs'].method_calls == [call.get_rows()]
    # An unknown table name raises KeyError from the side_effect's dict lookup.
    try:
        Table('table that is not mocked')
    except KeyError as e:
        print(repr(e))
KeyError('table that is not mocked',)
Pros: has call history. The common `mocked_table_instance` function can be reused in multiple tests, giving a centralised way of mocking `Table`.
Cons: longer patching.
Option 3: Create a class which simulates Table
.
class TableMock:
    """Hand-written fake for ``Table``.

    Rows must be supplied via ``mock_set_rows`` before ``get_rows`` is
    called; otherwise reading ``self.rows`` raises AttributeError.
    """

    def __init__(self, table_name):
        # Same attribute name as the real Table.
        self.table_name = table_name

    def mock_set_rows(self, rows):
        """Configure what ``get_rows`` will hand back."""
        self.rows = rows

    def get_rows(self):
        """Return the rows configured through ``mock_set_rows``."""
        configured_rows = self.rows
        return configured_rows
# Build one configured TableMock per table in row_data.
mocked_tables = {}
for name, rows in row_data.items():
    mocked_tables[name] = TableMock(name)
    mocked_tables[name].mock_set_rows(rows)
# While patched, Table(name) returns the matching TableMock.
with patch('__main__.Table', side_effect=lambda table_name: mocked_tables[table_name]) as table_class_mock:
    assert Table('users').get_rows() == row_data['users']
Pros: easiest to extend. Depending on how you write it, any method you did not implement on the fake raises `AttributeError` when called (good if you forgot to patch a method which touches real files/databases).
Cons: no call history.
Let's say you want to unit test the method `Database.copy_from(other_db)`, which makes several calls:
class Database:
    """Toy database whose ``copy_from`` method is the unit under test.

    All other methods are no-op placeholders, so an instance can be created
    and ``copy_from`` run without a real backend.
    """

    def copy_from(self, other_db, drop_all=False):
        """Copy users, jobs and categories from *other_db* into this database.

        If *drop_all* is true, ``delete_all()`` and ``create()`` run first.
        ``commit()`` is called at the end.
        """
        if drop_all:
            self.delete_all()
            self.create()
        self.add_users(other_db.get_users())
        self.add_jobs(other_db.get_jobs())
        self.add_categories(other_db.get_categories())
        self.commit()

    # Ugly way of defining all other methods: one shared placeholder.
    # It must accept arbitrary arguments. As a class attribute it is invoked
    # as a bound method (e.g. ``self.add_users(users)`` passes self and users).
    # A zero-argument ``noop()`` would raise TypeError on every call.
    def noop(*args, **kwargs):
        pass

    delete_all = create = add_users = add_jobs = add_categories = commit = noop
One usual way of doing it would be:
from unittest.mock import patch, DEFAULT
# Patch each method copy_from may call. Note that `create` is not listed.
with patch.multiple(Database, delete_all=DEFAULT, add_users=DEFAULT, add_jobs=DEFAULT, add_categories=DEFAULT, commit=DEFAULT) as mock_db:
    other_db_mock = Mock()  # `Mock` comes from an earlier cell's import
    db = Database()
    db.copy_from(other_db_mock)
There are a few problems here:
- The `patch` line is very long. The same would happen with multiple `patch.object(Database, ...)` calls (even longer). It gets worse when you also have to define `return_value` and `side_effect`.
- `method_calls` and `mock_calls` are not available, because the `Database` class itself is not mocked.
- It is easy to forget to patch a `Database` method that should never be executed in tests, especially after refactoring or adding more `add_xxxx` methods. (In the example above, `create` is already not patched.)

A new way of unit testing a single method of a class, while automatically patching all the others, is to call `Database.copy_from` (the plain function, accessed through the class) with a `Mock` object as `self`:
# Call the plain function through the class and pass a Mock as `self`:
# every self.xxx(...) call inside copy_from lands on the mock and is recorded.
db_mock = Mock()
other_db_mock = Mock()
Database.copy_from(db_mock, other_db_mock)
assert db_mock.method_calls == [
    call.add_users(other_db_mock.get_users()),
    call.add_jobs(other_db_mock.get_jobs()),
    call.add_categories(other_db_mock.get_categories()),
    call.commit()
]
In case you want to define custom `return_value`s or `side_effect`s for its methods, it is pretty easy and clean:
db_mock = Mock()
# Fail loudly if copy_from calls the methods reserved for drop_all=True.
db_mock.delete_all.side_effect = Exception("UNEXPECTED CALL!")
db_mock.create.side_effect = Exception("UNEXPECTED CALL!")
db_mock.commit.return_value = True
other_db_mock = Mock()
Database.copy_from(db_mock, other_db_mock)
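In case you want `db_mock` to have all attributes that are created/initialized in `__init__` (this example `Database` defines no `__init__` of its own, so here the call resolves to `object.__init__` and adds nothing):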
db_mock = Mock()
# Run __init__ against the mock so it receives whatever attributes __init__ sets.
Database.__init__(db_mock)
Following the previous approach, it is a bit tricky to call `Database.method` if the method is a `staticmethod` or `classmethod`:
class Algorithms:
    """Recursive Fibonacci exposed as a classmethod that logs every call."""

    @classmethod
    def cfib(cls, x):
        """Return the x-th Fibonacci number, recursing through ``cls.cfib``."""
        print("> called {}.cfib({})".format(cls, x))
        return x if x < 2 else cls.cfib(x-1) + cls.cfib(x-2)
We will see that `Algorithms.cfib` is bound to the class:
# A classmethod accessed through its class is already a method bound to that class.
print(Algorithms.cfib)
<bound method Algorithms.cfib of <class '__main__.Algorithms'>>
Since the function is already bound (to the class), we can't pass our own `cls` object:
m = Mock()
# Fails: cfib is already bound to Algorithms, so `m` becomes an extra argument.
Algorithms.cfib(m, 5)
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-33-1f19b728f1b6> in <module>() 1 m = Mock() ----> 2 Algorithms.cfib(m, 5) TypeError: cfib() takes 2 positional arguments but 3 were given
The solution is to unbind the function, which can be done by accessing the bound method's `__func__` attribute:
print(Algorithms.cfib)  # bound method (bound to the class)
print(Algorithms.cfib.__func__)  # the underlying plain function
<bound method Algorithms.cfib of <class '__main__.Algorithms'>> <function Algorithms.cfib at 0x7fe294c7f6a8>
m = MagicMock() # So that m.cfib returns a MagicMock, which you can sum with another MagicMock
# Call the plain function with the mock as `cls`; the recursive calls go to m.cfib.
print(Algorithms.cfib.__func__(m, 5))
print(m.method_calls)
> called <MagicMock id='140611134985888'>.cfib(5) <MagicMock name='mock.cfib().__add__()' id='140611135479704'> [call.cfib(4), call.cfib(3)]
Another issue comes up when you try to make `m.cfib` work like `Algorithms.cfib` does:
m = MagicMock()
# Attempt: make m.cfib behave like the real classmethod.
m.cfib = Algorithms.cfib
print(Algorithms.cfib.__func__(m, 3))
print(m.method_calls)
> called <MagicMock id='140611134984544'>.cfib(3) > called <class '__main__.Algorithms'>.cfib(2) > called <class '__main__.Algorithms'>.cfib(1) > called <class '__main__.Algorithms'>.cfib(0) > called <class '__main__.Algorithms'>.cfib(1) 2 []
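It does not work because, in the recursive calls, `cls` is not our mock object. `m.cfib` is a method bound to `Algorithms`, not to our mock! So `cls.cfib(x-1)` runs with `cls=Algorithms`. And since `m.cfib` is not a mock, nothing is recorded in `m.method_calls`.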
It is possible to change `cfib` and make it a method bound to `m`:
import types
m = MagicMock()
# Bind the plain function to m, so `cls` is m in every recursive call.
m.cfib = types.MethodType(Algorithms.cfib.__func__, m)
print(m.cfib(3))
print(m.method_calls)
> called <MagicMock id='140611134891848'>.cfib(3) > called <MagicMock id='140611134891848'>.cfib(2) > called <MagicMock id='140611134891848'>.cfib(1) > called <MagicMock id='140611134891848'>.cfib(0) > called <MagicMock id='140611134891848'>.cfib(1) 2 []
Perfect, but where is our call history? If you want call history, you must use side_effect or return_value.
Solution:
m = MagicMock()
# Keep m.cfib a mock (so calls are recorded) and delegate to the function bound to m.
m.cfib.side_effect = types.MethodType(Algorithms.cfib.__func__, m)
print(m.cfib(4))
print(m.method_calls)
> called <MagicMock id='140611134891624'>.cfib(4) > called <MagicMock id='140611134891624'>.cfib(3) > called <MagicMock id='140611134891624'>.cfib(2) > called <MagicMock id='140611134891624'>.cfib(1) > called <MagicMock id='140611134891624'>.cfib(0) > called <MagicMock id='140611134891624'>.cfib(1) > called <MagicMock id='140611134891624'>.cfib(2) > called <MagicMock id='140611134891624'>.cfib(1) > called <MagicMock id='140611134891624'>.cfib(0) 3 [call.cfib(4), call.cfib(3), call.cfib(2), call.cfib(1), call.cfib(0), call.cfib(1), call.cfib(2), call.cfib(1), call.cfib(0)]
Use wraps
to track history of calls on Algorithms.fib
(with recursion too!):
class Algorithms:
    """Fibonacci as a staticmethod that recurses through the class attribute."""

    @staticmethod
    def fib(x):
        """Return the x-th Fibonacci number.

        Recursion goes through ``Algorithms.fib``, looked up at call time, so
        patching that attribute also intercepts the recursive calls.
        """
        return x if x < 2 else Algorithms.fib(x-1) + Algorithms.fib(x-2)
# wraps= records each call and forwards it to the real fib. Because fib recurses
# via Algorithms.fib (now the mock), the recursive calls are recorded too.
with patch.object(Algorithms, 'fib', wraps=Algorithms.fib) as fib_mock:
    Algorithms.fib(4)
    print(fib_mock.mock_calls)
# side_effect= with the real function produces the same call history here.
with patch.object(Algorithms, 'fib', side_effect=Algorithms.fib) as fib_mock:
    Algorithms.fib(4)
    print(fib_mock.mock_calls)
[call(4), call(3), call(2), call(1), call(0), call(1), call(2), call(1), call(0)] [call(4), call(3), call(2), call(1), call(0), call(1), call(2), call(1), call(0)]
In case you don't want recursion and just check that fib(x)
calls fib(x-1)
and fib(x-2)
:
orig_fib = Algorithms.fib  # keep the real function before patching
# The recursive Algorithms.fib calls hit the mock (which returns 0), so only the
# direct calls fib(3) and fib(2) are recorded.
with patch.object(Algorithms, 'fib', return_value=0) as fib_mock:
    orig_fib(4)
    print(fib_mock.mock_calls)
[call(3), call(2)]
When you attach a mock as an attribute of another mock, it becomes a "child" of that mock. Calls to the child are recorded in the method_calls
and mock_calls
attributes of the parent.
If the child `Mock` has a `name`, the parent will not record that child's calls in its `method_calls` (or `mock_calls`):
m = Mock(name='parent')
child1 = Mock()  # unnamed: its calls are recorded on the parent
child2 = Mock(name='child_two')  # named: its calls are NOT recorded on the parent
m.child1 = child1
m.child2 = child2
child1('abc')
child2(1, 2, 3)
print(m.method_calls)
print(m.mock_calls)
[call.child1('abc')] [call.child1('abc')]
Mocks created by patch()
are automatically given names:
import datetime
m = Mock()
with patch('datetime.datetime') as child1:
    # The mock created by patch already has a name, so assigning it does not
    # make m record its calls.
    m.datetime = child1
    print(datetime.datetime.now())
    print(m.method_calls)
    print(m.datetime.method_calls)
<MagicMock name='datetime.now()' id='140611134495320'> [] [call.now()]
To attach mocks that have names to a parent, you can use the Mock
method attach_mock
:
import datetime
m = Mock()
with patch('datetime.datetime') as child1:
    # attach_mock re-parents the named mock, so its calls show up on m.
    m.attach_mock(child1, 'datetime')
    print(datetime.datetime.now())
    print(m.method_calls)
    print(m.datetime.method_calls)
<MagicMock name='mock.datetime.now()' id='140611134419184'> [call.datetime.now()] [call.now()]
`async` methods

To test projects which work with `async` and coroutines, I recommend using the `asynctest` library: https://pypi.python.org/pypi/asynctest/0.5.0
This library has a new mock object, `CoroutineMock`. It lets you define the `return_value` and `side_effect` of your functions without having to worry about them being `async`, `Future` objects or anything else. (Since Python 3.8, the standard library provides `unittest.mock.AsyncMock` for the same purpose.)
import asynctest
from asynctest import patch, Mock
import asyncio
# Create a fresh event loop and install it as the current loop.
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
class AsyncThing:
    """Example class with one coroutine method and one regular method."""

    async def method(self):
        """Wait 50 seconds, then return 5.

        Tests should patch ``asyncio.sleep`` to skip the real delay.
        """
        # `await` is required: calling asyncio.sleep() alone only creates a
        # coroutine object, which never runs (and triggers a
        # "coroutine was never awaited" warning).
        await asyncio.sleep(50)
        return 5

    def normal_method(self):
        """Return 123 synchronously."""
        return 123
class TestSomething(asynctest.TestCase):
    """asynctest examples: patching inside a coroutine test, and spec'd class mocks."""

    # Use the default event loop rather than a fresh loop per test.
    use_default_loop = True

    async def test_something(self):
        a = AsyncThing()
        # Replace asyncio.sleep so AsyncThing.method returns immediately.
        with patch('asyncio.sleep'):
            x = await a.method()
        assert x == 5

    async def test_class_mock(self):
        # Using `spec` makes it create CoroutineMock or MagicMock, depending on if the method is async or not.
        # Make sure to import patch from asynctest!
        with patch('__main__.AsyncThing', spec=AsyncThing) as asyncthing_mock:
            print(asyncthing_mock.method)
            print(asyncthing_mock.normal_method)


# loadTestsFromModule() expects a module; load the TestCase class directly.
suite = asynctest.TestLoader().loadTestsFromTestCase(TestSomething)
asynctest.TextTestRunner().run(suite)
..
<CoroutineMock name='AsyncThing.' id='140611135202976'> <MagicMock name='AsyncThing.' id='140611135202584'>
---------------------------------------------------------------------- Ran 2 tests in 0.016s OK
<unittest.runner.TextTestResult run=2 errors=0 failures=0>
One missing feature in `asynctest` is support for `async with` (the `__aenter__` and `__aexit__` methods): by default, an `asynctest` mock does not work in `async with`.
The current "best known" solution is to create a new `Mock` class:
class AsyncContextManagerMock(MagicMock):
    """MagicMock that can be used with ``async with``.

    Keyword options (removed before the remaining arguments reach MagicMock):

    * ``aenter_return``: the value bound by ``async with ... as value``.
      Defaults to a new MagicMock created only when the option is omitted.
    * ``aexit_return``: the value returned by ``__aexit__``. Defaults to
      ``None``, so exceptions raised inside the block propagate.

    NOTE(review): Python 3.8+ MagicMock supports ``__aenter__``/``__aexit__``
    natively and may install its own versions on each instance. Confirm
    these overrides still take effect there.
    """

    def __init__(self, *args, **kwargs):
        # Pop our own options so MagicMock does not configure them a second time.
        aenter_given = 'aenter_return' in kwargs
        aenter_return = kwargs.pop('aenter_return', None)
        aexit_return = kwargs.pop('aexit_return', None)
        super().__init__(*args, **kwargs)
        self.aenter_return = aenter_return if aenter_given else MagicMock()
        self.aexit_return = aexit_return

    async def __aenter__(self):
        return self.aenter_return

    async def __aexit__(self, exc_type, exc_value, traceback):
        return self.aexit_return
class TestSomething(asynctest.TestCase):
    """Shows AsyncContextManagerMock used in ``async with`` statements."""

    # Use the default event loop rather than a fresh loop per test.
    use_default_loop = True

    async def test_async_with(self):
        # Without options, the `as` target is the auto-created aenter_return mock.
        async with AsyncContextManagerMock() as mock:
            print('first', mock)
        # An explicit aenter_return is bound as-is.
        async with AsyncContextManagerMock(aenter_return=5) as value:
            print('second', value)


# loadTestsFromModule() expects a module; load the TestCase class directly.
suite = asynctest.TestLoader().loadTestsFromTestCase(TestSomething)
asynctest.TextTestRunner().run(suite)
.
first <MagicMock name='mock.aenter_return' id='140611128984688'> second 5
---------------------------------------------------------------------- Ran 1 test in 0.007s OK
<unittest.runner.TextTestResult run=1 errors=0 failures=0>