`with` statement
GitHub doesn't render large Jupyter Notebooks, so just in case, here is an nbviewer link to the notebook.
from contextlib import contextmanager
@contextmanager
def print_blue():
    """Print everything inside the ``with`` block in blue.

    On entry, writes the ANSI escape sequence for a blue foreground
    (``ESC[34m``). On exit, writes the "default foreground" sequence
    (``ESC[39m``). The reset is in ``finally`` so the terminal colour is
    restored even when the body raises.
    """
    print('\033[34m', end='')
    try:
        yield
    finally:
        print('\033[39m', end='')


with print_blue():
    print('Changes color in context')
print('Outside the context with default color')
Changes color in context
Outside the context with default color
Intentionally suppress an expected error. This approach reduces the visual noise of a try/except block.
from contextlib import suppress
import os
# Delete file.txt if it exists; a missing file is expected, so its
# FileNotFoundError is swallowed instead of being handled by try/except.
with suppress(FileNotFoundError):
    os.unlink('file.txt')
compared
import logging
@contextmanager
def debug_logging(logger_name: str, level: int):
    """Temporarily change the level of a named logger and yield that logger.

    Args:
        logger_name: name passed to ``logging.getLogger``.
        level: level to apply while the ``with`` body runs
            (e.g. ``logging.DEBUG``).

    The logger's own previous level is restored on exit, even if the body
    raises.
    """
    logger = logging.getLogger(logger_name)
    # Save the logger's *own* level rather than getEffectiveLevel(). A logger
    # left at NOTSET inherits its level from its ancestors. Restoring the
    # effective value would pin it to a concrete level and break inheritance.
    old_level = logger.level
    logger.setLevel(level)
    try:
        yield logger
    finally:
        logger.setLevel(old_level)


with debug_logging('my-logger', logging.DEBUG) as logger:
    # NOTE(review): this record only reaches the console if some handler
    # accepts DEBUG (e.g. after logging.basicConfig); Python's last-resort
    # handler only emits WARNING and above.
    logger.debug('This will be printed')
logging.getLogger('my-logger').info('This wont be logged because default level is WARNING')
@contextmanager
def tag(name):
    """Surround whatever the ``with`` body prints with ``<name>`` ... ``</name>``.

    Nothing is printed after the body if it raises (there is no ``finally``).
    """
    print('<{}>'.format(name), end='')
    yield
    print('</{}>'.format(name), end='')


with tag('header'):
    print('Tag body', end='')
<header>Tag body</header>
class Indenter:
    """Context manager that prints text indented by its current nesting depth.

    Every ``with`` on the same instance deepens the indent by one space, and
    leaving the block undoes it. Exceptions from the body are not suppressed.
    """

    def __init__(self):
        self.level = 0  # number of currently open ``with`` blocks on this instance

    def __enter__(self):
        self.level = self.level + 1
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.level = self.level - 1
        # Falls through to return None, so any exception propagates.

    def print(self, text):
        """Print *text* with one leading space per level beyond the first."""
        padding = ' ' * (self.level - 1)
        print(padding + text)


with Indenter() as indenter:
    indenter.print('def mimic_python_syntax():')
    with indenter:
        indenter.print('s = "Hello World"')
        indenter.print('print(s)')
    indenter.print('\nmimic_python_syntax()')
def mimic_python_syntax(): s = "Hello World" print(s) mimic_python_syntax()
Make a copy of the input list's items and work on the copy. If no error occurs, replace the input list's items with the items from the list used in the context.
@contextmanager
def list_transaction(list_: list):
    """Yield a working copy of *list_*; write it back only if the body succeeds.

    The ``with`` body mutates the yielded copy. If the body finishes without
    raising, the contents of ``list_`` are replaced in place by the copy's
    contents. If it raises, ``list_`` is left untouched.
    """
    draft = [*list_]
    yield draft
    # Only reached when the body did not raise.
    list_[:] = draft
items = [1, 2, 3]
# The body raises on purpose: the appended 4 is discarded and items is unchanged.
with list_transaction(items) as working:
    working += [4]
    raise RuntimeError()
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-8-e92364a7c13f> in <module> 3 with list_transaction(items) as working: 4 working.append(4) ----> 5 raise RuntimeError() RuntimeError:
items  # bare expression: in a notebook this displays the current contents
with list_transaction(items) as working:
    working.extend([4])
print(items)
Roll back changes if an error occurs in the context
class Transaction:
    """Commit on a clean exit; roll back if the ``with`` body raises.

    ``__enter__`` returns the wrapped connection itself, so the body calls
    ``execute`` etc. directly on it. Exceptions are never suppressed.
    """

    def __init__(self, connection):
        self.connection = connection

    def __enter__(self):
        return self.connection

    def __exit__(self, err_type, err_value, err_traceback):
        if not err_type:
            self.connection.commit()
            return
        self.connection.rollback()
import sqlite3
create_users_sql = """
CREATE TABLE Users
(
id INTEGER PRIMARY KEY,
name TEXT NOT NULL
)
"""
insert_user_sql = """
INSERT INTO Users
(id, name) VALUES (1, 'Name 1')
"""

# An empty filename gives sqlite3 a private temporary database.
connection = sqlite3.connect('')
with Transaction(connection) as t:
    t.execute(create_users_sql)
    t.execute(insert_user_sql)
    rows = t.execute("""SELECT * FROM Users""").fetchall()
    print(rows)
[(1, 'Name 1')]
Example from Python Cookbook, 3rd Edition
from socket import socket, AF_INET, SOCK_STREAM
class LazyConnection:
    """Context manager that opens a socket to *address* on ``__enter__``.

    The socket is created and connected only when the ``with`` block is
    entered, and it is closed on exit. One instance can therefore be reused
    for successive (non-nested) connections.

    Args:
        address: address passed to ``socket.connect`` (e.g. ``(host, port)``).
        family: address family of the socket, default ``AF_INET``.
        type: socket type, default ``SOCK_STREAM``.

    Raises:
        RuntimeError: from ``__enter__`` if this instance is already connected.
    """

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        # Honour the caller's arguments instead of hard-coding AF_INET/SOCK_STREAM.
        self.family = family
        self.type = type
        self.sock = None

    def __enter__(self):
        if self.sock is not None:
            raise RuntimeError('Already connected')
        sock = socket(self.family, self.type)
        try:
            sock.connect(self.address)
        except BaseException:
            # Do not leak the socket. Leave self.sock as None so the instance
            # is not stuck in the "Already connected" state after a failure.
            sock.close()
            raise
        self.sock = sock
        return self.sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.sock.close()
        self.sock = None
        # Returning None lets exceptions from the body propagate.
from functools import partial
connection = LazyConnection(('www.python.org', 80))
request_lines = (
    b'GET /index.html HTTP/1.0\r\n',
    b'Host: www.python.org\r\n',
    b'\r\n',
)
with connection as s:
    for request_line in request_lines:
        s.send(request_line)
    # Read 8192-byte chunks until recv() returns b'' (the peer closed the connection).
    resp = b''.join(iter(partial(s.recv, 8192), b''))
print(resp)
b'HTTP/1.1 301 Moved Permanently\r\nServer: Varnish\r\nRetry-After: 0\r\nLocation: https://www.python.org/index.html\r\nContent-Length: 0\r\nAccept-Ranges: bytes\r\nDate: Wed, 14 Apr 2021 09:31:05 GMT\r\nVia: 1.1 varnish\r\nConnection: close\r\nX-Served-By: cache-ams21030-AMS\r\nX-Cache: HIT\r\nX-Cache-Hits: 0\r\nX-Timer: S1618392666.941079,VS0,VE0\r\nStrict-Transport-Security: max-age=63072000; includeSubDomains\r\n\r\n'
import time
@contextmanager
def stopwatch(label: str):
    """Print how long the ``with`` body took, as ``'<label>: <seconds>'``.

    The elapsed time is printed even if the body raises.
    """
    # perf_counter() is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP sync, manual changes) and skew durations.
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        print(f'{label}: {end - start}')
# Time a one-second sleep; stopwatch prints "Sleeping: <seconds>".
with stopwatch('Sleeping'):
    time.sleep(1.0)
Sleeping: 1.0047638416290283
import time
class Stopwatch:
    """Measure how long a ``with`` body takes and report it through a callback.

    Args:
        output_callable: called once on exit with the elapsed time in seconds
            (a float), whether or not the body raised.
    """

    def __init__(self, output_callable):
        self.output_callable = output_callable

    def __enter__(self):
        # perf_counter() is monotonic; time.time() can jump when the clock changes.
        self.start = time.perf_counter()
        # Return self so ``with Stopwatch(...) as sw`` binds something useful.
        return self

    def __exit__(self, err_type, err_value, err_traceback):
        end = time.perf_counter()
        self.output_callable(end - self.start)
        # Returning None lets any exception from the body propagate.
import logging
# With no arguments, basicConfig() adds a default stderr handler to the root
# logger (only if the root logger has no handlers yet).
logging.basicConfig()
logger = logging.getLogger('stopwatch')
logger.setLevel(logging.DEBUG)
report_elapsed = logger.info
with Stopwatch(report_elapsed):
    time.sleep(1)
INFO:stopwatch:1.00174880027771
class MultilevelStopwatch:
    """Nestable timer: each ``with`` level on one instance reports its own duration.

    Re-entering the same instance pushes a new start time. Leaving pops it and
    prints ``Level <depth> took: <seconds>``, where depth 1 is the outermost.
    """

    def __init__(self):
        self.levels = []  # start timestamps of open levels, innermost last

    def __enter__(self):
        # perf_counter() is monotonic; time.time() can jump when the clock changes.
        self.levels.append(time.perf_counter())
        return self

    def __exit__(self, err_type, err_value, err_traceback):
        latest = self.levels.pop()
        end = time.perf_counter()
        print(f'Level {len(self.levels)+1} took: {end - latest}')
# The inner level is reported first (Level 2), then the outer one (Level 1).
with MultilevelStopwatch() as ms:
    time.sleep(0.5)
    with ms:
        time.sleep(0.5)
Level 2 took: 0.5035400390625 Level 1 took: 1.0045082569122314
Reuse a TCP connection to improve performance
# install requests if necessary
# !pip3 install requests
import requests
n = 20
url = "http://httpbin.org/cookies/set/sessioncookie/123456789"

# One Session for all n requests, so its connection can be reused.
with stopwatch('Using context manager'):
    with requests.Session() as session:
        for _ in range(n):
            session.get(url)

# Module-level requests.get() for every request: no shared Session.
with stopwatch('Establishing HTTP connection for every request'):
    for _ in range(n):
        requests.get(url)
Using context manager: 7.110113143920898 Establishing HTTP connection for every request: 10.991943836212158
from contextlib import contextmanager
@contextmanager
def get_state(name):
    """Announce entering and leaving a named context; yield the name itself."""
    print("entering:", name)
    try:
        yield name
    finally:
        # Runs both on normal exit and when the body raises.
        print("exiting:", name)


# Several managers in one ``with``: entered left to right, exited right to left.
with get_state("A") as A, \
        get_state('B') as B, \
        get_state("C") as C:
    print("inside with statement:", A, B, C)
entering: A entering: B entering: C inside with statement: A B C exiting: C exiting: B exiting: A
The example above, rewritten using ExitStack
from contextlib import ExitStack
with ExitStack() as es:
    # ExitStack unwinds in reverse order of entry: C, then B, then A.
    for state_name in ('A', 'B', 'C'):
        es.enter_context(get_state(state_name))
    print('Inside')
entering: A entering: B entering: C Inside exiting: C exiting: B exiting: A
Contexts that were already entered still exit properly when a later one raises
@contextmanager
def raise_error(name, err):
    """Print entry/exit messages around raising ``err()``.

    NOTE: the body contains no ``yield``, so this is not a generator. The
    exception is raised as soon as ``raise_error(...)`` is called, while the
    context expression is evaluated, before any ``__enter__`` runs.
    """
    print("entering:", name)
    try:
        exc = err()
        raise exc
    finally:
        print('exiting:', name)
try:
    # A is entered, B fails, C is never reached; A is still exited.
    with get_state("A") as A, raise_error('B', RuntimeError) as B, get_state("C") as C:
        print('Inside')
except RuntimeError as caught:
    print('Caught error', caught)
entering: A entering: B exiting: B exiting: A Caught error
import logging
from contextlib import contextmanager
import traceback
import sys
# Keep a handle on this module's logger (getLogger returns it).
logger = logging.getLogger(__name__)
# NOTE: basicConfig() does nothing when the root logger already has handlers.
# An earlier cell in this notebook calls logging.basicConfig(), so in a
# top-to-bottom run this level/format may not take effect.
logging.basicConfig(
    level=logging.INFO,
    format="\n%(asctime)s [%(levelname)s] %(message)s",
)
class Divider:
    """Callable ``a / b`` that reports ZeroDivisionError instead of raising it.

    Calling an instance returns the quotient, or ``None`` when the division
    raises ZeroDivisionError. The error is printed and logged, then suppressed
    by ``errorhandler``. Other exceptions propagate.
    """

    @contextmanager
    def errorhandler(self):
        """Report and suppress a ZeroDivisionError raised in the ``with`` body."""
        try:
            yield
        except ZeroDivisionError:
            notice = ("Custom handling of Zero Division Error! Printing "
                      "only 2 levels of traceback..")
            print(notice)
            # Logged on the root logger with the active traceback attached.
            logging.exception("ZeroDivisionError")

    def __call__(self, a, b):
        """Return ``a / b``; error handling is delegated to ``errorhandler``."""
        with self.errorhandler():
            quotient = a / b
            return quotient
divide = Divider()
# Division by zero is reported by Divider.errorhandler instead of raising.
divide(a=2, b=0)
ERROR:root:ZeroDivisionError Traceback (most recent call last): File "<ipython-input-26-fc2598d2335e>", line 19, in errorhandler yield File "<ipython-input-26-fc2598d2335e>", line 31, in __call__ return a / b ZeroDivisionError: division by zero
Custom handling of Zero Division Error! Printing only 2 levels of traceback..