from peewee import *
from playhouse.sqlite_ext import *

# Demo database with the SQLite transitive-closure extension loaded.
db = SqliteExtDatabase('/tmp/categories.db')
# Path to the compiled extension; the ".so" suffix is deliberately omitted.
db.load_extension('/home/charles/envs/notebook/closure')


def _print_names(categories):
    """Print the ``name`` of every category in *categories* as one list."""
    print([category.name for category in categories])


class Category(Model):
    """A node in the category tree; root nodes have ``parent`` set to None."""
    name = CharField()
    parent = ForeignKeyField('self', null=True, related_name='children')

    class Meta:
        database = db


# Virtual table provided by the transitive-closure extension, built over the
# Category hierarchy.
CategoryClosure = ClosureTable(Category)

# create_table(True): only create the tables when they are missing.
Category.create_table(True)
CategoryClosure.create_table(True)

# Start every run from an empty Category table.
Category.delete().execute()

# Books
# |-- Fiction
# |   |-- Sci-fi
# |   `-- Westerns
# `-- Non-fiction
books = Category.create(name='Books')
fiction = Category.create(name='Fiction', parent=books)
scifi = Category.create(name='Sci-fi', parent=fiction)
westerns = Category.create(name='Westerns', parent=fiction)
non_fiction = Category.create(name='Non-fiction', parent=books)

# ----- Every descendant of "Books" -----

# (1) JOIN the closure table on the node id and filter on its root.
all_descendants = (
    Category.select()
    .join(CategoryClosure, on=(Category.id == CategoryClosure.id))
    .where(CategoryClosure.root == books)
)
_print_names(all_descendants)

# (2) Subquery instead of a JOIN; peewee's "<<" operator renders as SQL IN.
subquery = (
    CategoryClosure.select(CategoryClosure.id)
    .where(CategoryClosure.root == books)
)
all_descendants = Category.select().where(Category.id << subquery)
_print_names(all_descendants)

# (3) The ClosureTable descendants() helper, with include_node=True.
all_descendants = CategoryClosure.descendants(books, include_node=True)
_print_names(all_descendants)

# ----- Direct children of "Fiction" -----

# (1) One level down only needs the Category table itself.
direct_descendants = Category.select().where(Category.parent == fiction)
_print_names(direct_descendants)

# (2) JOIN the closure table and keep only rows at depth 1.
direct_descendants = (
    Category.select()
    .join(CategoryClosure, on=(Category.id == CategoryClosure.id))
    .where((CategoryClosure.root == fiction) &
           (CategoryClosure.depth == 1))
)
_print_names(direct_descendants)

# (3) A subquery works here too.
subquery = (CategoryClosure
            .select(CategoryClosure.id)
            .where((CategoryClosure.root == fiction) &
                   (CategoryClosure.depth == 1)))
direct_descendants = Category.select().where(Category.id << subquery)
print([cat.name for cat in direct_descendants])

# Using the ClosureTable helper, limited to one level below "Fiction".
direct_descendants = CategoryClosure.descendants(fiction, depth=1)
# Print the result like every other example does.
print([cat.name for cat in direct_descendants])

# ----- Siblings of "Sci-fi" (Sci-fi itself is included) -----

# We can use just the Category table: everything sharing Sci-fi's parent.
siblings = Category.select().where(Category.parent == scifi.parent)
print([cat.name for cat in siblings])

# Or use the closure table: the direct children of Sci-fi's parent.
siblings = (Category
            .select()
            .join(CategoryClosure, on=(Category.id == CategoryClosure.id))
            .where((CategoryClosure.root == scifi.parent) &
                   (CategoryClosure.depth == 1)))
print([cat.name for cat in siblings])

# Using the ClosureTable helper.
siblings = CategoryClosure.siblings(scifi, include_node=True)
print([cat.name for cat in siblings])

# ----- Ancestors of "Sci-fi" -----

# Using a JOIN: closure rows whose descendant is Sci-fi point (via root) at
# each ancestor.
ancestors = (Category
             .select()
             .join(CategoryClosure, on=(Category.id == CategoryClosure.root))
             .where(CategoryClosure.id == scifi))
print([cat.name for cat in ancestors])

# Using multiple tables in the FROM clause (an implicit join).
ancestors = (Category
             .select()
             .from_(Category, CategoryClosure)
             .where((Category.id == CategoryClosure.root) &
                    (CategoryClosure.id == scifi)))
print([cat.name for cat in ancestors])

# Using the ClosureTable helper.
ancestors = CategoryClosure.ancestors(scifi, include_node=True)
print([cat.name for cat in ancestors])

# ---------------------------------------------------------------------------
# Benchmark: transitive closure table vs. materialized path.
# ---------------------------------------------------------------------------
from collections import namedtuple
import operator
import os
import time

from peewee import *
from playhouse.sqlite_ext import *
from playhouse.test_utils import count_queries

# Transitive-closure database.
db_tc = SqliteExtDatabase('/tmp/bench-tc.db')
db_tc.load_extension('/home/charles/envs/notebook/closure')

# Materialized path database.
db_mp = SqliteExtDatabase('/tmp/bench-mp.db')


class CategoryTC(Model):
    """Category tree stored as parent pointers, queried via the closure table."""
    name = CharField()
    parent = ForeignKeyField('self', index=True, null=True)
    # Distance from the root; maintained by save().
    depth = IntegerField(default=0)

    class Meta:
        database = db_tc

    def __unicode__(self):
        return self.name

    def save(self, *args, **kwargs):
        """Recompute ``depth`` from the parent, then delegate to Model.save.

        Returns the result of ``Model.save``.
        """
        if self.parent:
            self.depth = self.parent.depth + 1
        else:
            # Reset when the node has no parent, matching CategoryMP.save.
            self.depth = 0
        return super(CategoryTC, self).save(*args, **kwargs)

    @classmethod
    def query_direct_children(cls, name):
        """Return a query for the immediate children of the node ``name``.

        Raises the model's DoesNotExist if no node has that name.
        """
        cat = cls.get(cls.name == name)
        # CategoryClosure is defined below this class; it is looked up when
        # the method runs, not when the class is created.
        return (cls
                .select()
                .join(CategoryClosure, on=(cls.id == CategoryClosure.id))
                .where((CategoryClosure.root == cat) &
                       (CategoryClosure.depth == 1)))

    @classmethod
    def query_all_children(cls, name):
        """Return a query over closure rows rooted at the node ``name``."""
        cat = cls.get(cls.name == name)
        return (cls
                .select()
                .join(CategoryClosure, on=(cls.id == CategoryClosure.id))
                .where(CategoryClosure.root == cat))


CategoryClosure = ClosureTable(CategoryTC)


class CategoryMP(Model):
    """Category tree stored with a dot-separated materialized ``path``."""
    name = CharField()
    path = CharField(index=True)
    parent = ForeignKeyField('self', index=True, null=True)
    # Distance from the root; maintained by save().
    depth = IntegerField(default=0)

    class Meta:
        database = db_mp

    def save(self, *args, **kwargs):
        """Recompute ``path`` and ``depth`` from the parent, then save.

        Returns the result of ``Model.save``, like ``CategoryTC.save``.
        """
        if self.parent:
            # NOTE(review): create_tree builds names that already embed the
            # parent's name (e.g. "0.1"), so paths repeat ancestor segments
            # ("0" -> "0.0.1"). Prefix matching still appears to select exactly
            # the descendants, but paths grow faster than depth.
            self.path = self.parent.path + '.' + self.name
            self.depth = self.parent.depth + 1
        else:
            self.path = self.name
            self.depth = 0
        return super(CategoryMP, self).save(*args, **kwargs)

    @classmethod
    def query_direct_children(cls, name):
        """Return a query for the immediate children of the node ``name``."""
        cat = cls.get(cls.name == name)
        return cls.select().where(
            (cls.path.startswith(cat.path + '.')) &
            (cls.depth == cat.depth + 1))

    @classmethod
    def query_all_children(cls, name):
        """Return a query for every node whose path lies under ``name``'s."""
        cat = cls.get(cls.name == name)
        # The trailing '.' keeps e.g. "1" from matching "10...".
        return cls.select().where(cls.path.startswith(cat.path + '.'))


class timed(object):
    """Context manager storing the elapsed wall-clock seconds in ``.time``.

    ``time`` stays None until the ``with`` block exits. Exceptions raised
    inside the block are not suppressed.
    """

    def __init__(self):
        self.time = None

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args, **kwargs):
        self.time = time.time() - self.start


# Helper function to construct a tree given a list of node counts
# describing the branching factor at each level of the tree.
def create_tree(model_class, node_counts):
    """Populate ``model_class`` with a tree shaped by ``node_counts``.

    ``node_counts[0]`` is the number of roots. ``node_counts[i]`` is the
    number of children created under each node at level ``i - 1``. A node is
    named with its parent's name plus its own index, e.g. ``"0.1.2"``.
    """
    def build_tree(parent, idx):
        if idx == len(node_counts):
            return
        for i in range(node_counts[idx]):
            if parent is not None:
                name = '%s.%s' % (parent.name, i)
            else:
                name = str(i)
            inst = model_class.create(name=name, parent=parent)
            build_tree(inst, idx + 1)

    build_tree(None, 0)


def query(model_class, names, iters=3):
    """Run both child queries ``iters`` times for each name in ``names``.

    Each query is consumed with ``list``; otherwise it would not be
    evaluated.
    """
    for name in names:
        for _ in range(iters):
            list(model_class.query_all_children(name))
            list(model_class.query_direct_children(name))


def make_names(structure):
    """Return the node names to query for a tree shaped by ``structure``.

    For each level, take the first ``min(width, 10)`` nodes below the
    all-zeros branch. Levels with no nodes contribute no names.

    >>> make_names([2, 2, 2])
    ['0', '1', '0.0', '0.1', '0.0.0', '0.0.1']
    """
    names = []
    for level, width in enumerate(structure):
        prefix = ['0'] * level
        for i in range(min(width, 10)):
            names.append('.'.join(prefix + [str(i)]))
    return names


class Result(namedtuple('_R', ('structure', 'time', 'queries', 'timings'))):
    """Outcome of one benchmark run.

    ``time`` is the seconds spent building the tree, ``queries`` is the number
    of queries issued while building it, and ``timings`` is a list of
    ``(depth, seconds)`` pairs, one per queried node.
    """

    def __new__(cls, structure, time, queries):
        # Each result gets its own, initially empty, timings list.
        return super(Result, cls).__new__(cls, structure, time, queries, [])

    def __repr__(self):
        return '<Result: %s nodes, %s queries, %.4fs build, %.4fs avg query>' % (
            self.nodes, self.queries, self.time, self.avg)

    @property
    def nodes(self):
        """Total node count: the sum over levels of the product of branching
        factors down to that level."""
        total = 0
        level_size = 1
        for count in self.structure:
            level_size *= count
            total += level_size
        return total

    @property
    def avg(self):
        """Mean of the recorded query timings, or 0.0 when there are none."""
        if not self.timings:
            return 0.0
        return sum(timing for _, timing in self.timings) / len(self.timings)

    def add_result(self, name, timing):
        """Record ``timing`` for node ``name``, keyed by the node's depth."""
        depth = len(name.split('.'))
        self.timings.append((depth, timing))


def benchmark(model_class, structure):
    """Build a fresh ``model_class`` tree shaped like ``structure``.

    Then time the child queries for each name from ``make_names``.
    Returns a ``Result``.
    """
    # Ensure both database connections are closed before deleting the files.
    for database in (db_tc, db_mp):
        if not database.is_closed():
            database.close()

    # Delete any database files on disk so each run starts fresh.
    for filename in ('/tmp/bench-tc.db', '/tmp/bench-mp.db'):
        if os.path.exists(filename):
            os.unlink(filename)

    # Re-create the tables.
    CategoryTC.create_table()
    CategoryClosure.create_table()
    CategoryMP.create_table()

    with timed() as timer:
        with count_queries() as cq:
            create_tree(model_class, structure)
    # timer.time is only set once the ``timed`` block has exited.
    result = Result(structure, timer.time, cq.count)

    for name in make_names(structure):
        with timed() as timer:
            # Time only this node's queries so each timing belongs to its
            # own depth.
            query(model_class, [name])
        result.add_result(name, timer.time)

    return result


structures = [
    ('Deep', [3, 3, 3, 2, 2, 2, 2, 2, 2]),
    ('Very deep', [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]),
    ('Uniform', [5, 5, 5, 5, 3]),
    ('Top-heavy', [50, 5, 5, 1]),
    ('Bottom-heavy', [1, 1, 5, 10, 50]),
    ('Very wide', [1000, 1]),
]

# Silence DEBUG-level (and lower) logging while the benchmark runs.
import logging
logging.disable(logging.DEBUG)

# Run the benchmark!
for label, structure in structures:
    print(label)
    for model_class in [CategoryMP, CategoryTC]:
        result = benchmark(model_class, structure)
        print('%s: %r' % (model_class.__name__, result))