SQLite transitive closure extension demo

In [2]:
from peewee import *
from playhouse.sqlite_ext import *

db = SqliteExtDatabase('/tmp/categories.db')
db.load_extension('/home/charles/envs/notebook/closure')  # Note we leave off .so
In [3]:
class Category(Model):
    name = CharField()
    parent = ForeignKeyField('self', null=True, related_name='children')

    class Meta:
        database = db

# Create a virtual table using the transitive closure extension.
CategoryClosure = ClosureTable(Category)
In [4]:
# Create the tables if they do not exist already.
Category.create_table(True)
CategoryClosure.create_table(True)
Category.delete().execute()  # Blow away any data.
Out[4]:
0
In [5]:
books = Category.create(name='Books')
fiction = Category.create(name='Fiction', parent=books)
scifi = Category.create(name='Sci-fi', parent=fiction)
westerns = Category.create(name='Westerns', parent=fiction)
non_fiction = Category.create(name='Non-fiction', parent=books)

Get all descendant nodes

In [7]:
# Using a JOIN.
all_descendants = (Category
                   .select()
                   .join(CategoryClosure, on=(Category.id == CategoryClosure.id))
                   .where(CategoryClosure.root == books))
print [cat.name for cat in all_descendants]

# Using a subquery instead. "<<" translates to "IN".
subquery = (CategoryClosure
            .select(CategoryClosure.id)
            .where(CategoryClosure.root == books))
all_descendants = Category.select().where(Category.id << subquery)
print [cat.name for cat in all_descendants]

# Using the helper method.
all_descendants = CategoryClosure.descendants(
    books, include_node=True)
print [cat.name for cat in all_descendants]
[u'Books', u'Fiction', u'Sci-fi', u'Westerns', u'Non-fiction']
[u'Books', u'Fiction', u'Sci-fi', u'Westerns', u'Non-fiction']
[u'Books', u'Fiction', u'Sci-fi', u'Westerns', u'Non-fiction']

Get direct descendant nodes (child nodes)

In [10]:
# We can use just the Category table in this case.
direct_descendants = Category.select().where(Category.parent == fiction)
print [cat.name for cat in direct_descendants]

# We can join on the closure table.
direct_descendants = (Category
                      .select()
                      .join(
                          CategoryClosure,
                          on=(Category.id == CategoryClosure.id))
                      .where(
                          (CategoryClosure.root == fiction) &
                          (CategoryClosure.depth == 1)))
print [cat.name for cat in direct_descendants]

# We can use a subquery.
subquery = (CategoryClosure
            .select(CategoryClosure.id)
            .where(
                (CategoryClosure.root == fiction) &
                (CategoryClosure.depth == 1)))
direct_descendants = Category.select().where(Category.id << subquery)
print [cat.name for cat in direct_descendants]

# Using helper method.
direct_descendants = CategoryClosure.descendants(fiction, depth=1)
[u'Sci-fi', u'Westerns']
[u'Sci-fi', u'Westerns']
[u'Sci-fi', u'Westerns']

Get all sibling nodes

In [13]:
# We can use just the Category table.
siblings = Category.select().where(Category.parent == scifi.parent)
print [cat.name for cat in siblings]

# Or use the closure table.
siblings = (Category
            .select()
            .join(CategoryClosure, on=(Category.id == CategoryClosure.id))
            .where(
                (CategoryClosure.root == scifi.parent) &
                (CategoryClosure.depth == 1)))
print [cat.name for cat in siblings]

# Using helper method.
siblings = CategoryClosure.siblings(scifi, include_node=True)
print [cat.name for cat in siblings]
[u'Sci-fi', u'Westerns']
[u'Sci-fi', u'Westerns']
[u'Sci-fi', u'Westerns']

Get all ancestors

In [14]:
# Using a JOIN.
ancestors = (Category
             .select()
             .join(CategoryClosure, on=(Category.id == CategoryClosure.root))
             .where(CategoryClosure.id == scifi))
print [cat.name for cat in ancestors]

# Using multiple tables in the FROM clause.
ancestors = (Category
             .select()
             .from_(Category, CategoryClosure)
             .where(
                 (Category.id == CategoryClosure.root) &
                 (CategoryClosure.id == scifi)))
print [cat.name for cat in ancestors]

# Using helper method.
ancestors = CategoryClosure.ancestors(scifi, include_node=True)
print [cat.name for cat in ancestors]
[u'Books', u'Fiction', u'Sci-fi']
[u'Books', u'Fiction', u'Sci-fi']
[u'Books', u'Fiction', u'Sci-fi']

Benchmarking performance

In this sample benchmark, we'll measure the performance of SQLite using materialized path (as a delimited string with LIKE queries) and transitive closures. The test will measure the performance of inserting rows into the table and querying for direct descendants and all descendants of a sample set of rows. We will test different size trees to see if the table structure has any impact on performance.

In [15]:
from collections import namedtuple
import operator
import os
import time

from peewee import *
from playhouse.sqlite_ext import *
from playhouse.test_utils import count_queries

# Transitive-closure database.
db_tc = SqliteExtDatabase('/tmp/bench-tc.db')
db_tc.load_extension('/home/charles/envs/notebook/closure')

# Materialized path database.
db_mp = SqliteExtDatabase('/tmp/bench-mp.db')
In [16]:
class CategoryTC(Model):
    name = CharField()
    parent = ForeignKeyField('self', index=True, null=True)
    depth = IntegerField(default=0)

    class Meta:
        database = db_tc

    def __unicode__(self):
        return self.name

    def save(self, *args, **kwargs):
        if self.parent:
            self.depth = self.parent.depth + 1
        return super(CategoryTC, self).save(*args, **kwargs)

    @classmethod
    def query_direct_children(cls, name):
        cat = cls.get(cls.name == name)
        return (cls
                .select()
                .join(CategoryClosure, on=(cls.id == CategoryClosure.id))
                .where((CategoryClosure.root == cat) & (CategoryClosure.depth == 1)))

    @classmethod
    def query_all_children(cls, name):
        cat = cls.get(cls.name == name)
        return (cls
                .select()
                .join(CategoryClosure, on=(cls.id == CategoryClosure.id))
                .where(CategoryClosure.root == cat))

CategoryClosure = ClosureTable(CategoryTC)
In [17]:
class CategoryMP(Model):
    name = CharField()
    path = CharField(index=True)
    parent = ForeignKeyField('self', index=True, null=True)
    depth = IntegerField(default=0)

    class Meta:
        database = db_mp

    def save(self, *args, **kwargs):
        if self.parent:
            self.path = self.parent.path + '.' + self.name
            self.depth = self.parent.depth + 1
        else:
            self.path = self.name
            self.depth = 0
        inst = super(CategoryMP, self).save(*args, **kwargs)

    @classmethod
    def query_direct_children(cls, name):
        cat = cls.get(cls.name == name)
        return cls.select().where(
            (cls.path.startswith(cat.path + '.')) &
            (cls.depth == cat.depth + 1))

    @classmethod
    def query_all_children(cls, name):
        cat = cls.get(cls.name == name)
        return cls.select().where(cls.path.startswith(cat.path + '.'))
In [18]:
# Helper class to measure the time it takes to run one or more operations.
class timed(object):
    def __init__(self):
        self.time = None

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args, **kwargs):
        self.time = time.time() - self.start
In [19]:
# Helper function to construct a tree given a list of node counts
# describing the branching factor at each level of the tree.
def create_tree(model_class, node_counts):
    def build_tree(parent, idx):
        if idx == len(node_counts):
            return
        for i in range(node_counts[idx]):
            if parent:
                name = '%s.%s' % (parent.name, i)
            else:
                name = str(i)
            inst = model_class.create(name=name, parent=parent)
            build_tree(inst, idx + 1)

    build_tree(None, 0)
In [20]:
# Helper function to query the table for all descendants and direct
# descendants. The query is "consumed" by calling `list`, otherwise
# it would not be evaluated.
def query(model_class, names, iters=3):
    for name in names:
        for i in range(iters):
            list(model_class.query_all_children(name))
            list(model_class.query_direct_children(name))
In [21]:
# Helper function to generate a list of node names to query.
def make_names(structure):
    depth = len(structure)
    names = []
    for i in range(depth):
        n_at_depth = min(structure[i], 10)
        nums = ['0'] * (i + 1)
        names.append('.'.join(nums))
        for i in range(1, n_at_depth):
            nums[-1] = str(i)
            names.append('.'.join(nums))
    return names
In [22]:
# Showing how make_names works:
make_names([2, 2, 2])
Out[22]:
['0', '1', '0.0', '0.1', '0.0.0', '0.0.1']
In [23]:
# Helper class to store the results of a particular benchmark.
class Result(namedtuple('_R', ('structure', 'time', 'queries', 'timings'))):
    def __new__(cls, structure, time, queries):
        return super(Result, cls).__new__(cls, structure, time, queries, [])

    def __repr__(self):
        return '<Result: %s nodes, created in %.3f, avg %.3f>' % (
            self.nodes,
            self.time,
            self.avg)

    @property
    def nodes(self):
        return sum(reduce(operator.mul, self.structure[:i + 1])
                   for i in range(len(self.structure)))

    @property
    def avg(self):
        return sum(timing for _, timing in self.timings) / len(self.timings)

    def add_result(self, name, timing):
        depth = len(name.split('.'))
        self.timings.append((depth, timing))
In [24]:
# The actual benchmarking code.
def benchmark(model_class, structure):
    # Ensure the database connection is closed, if it is not already.
    if not db_tc.is_closed():
        db_tc.close()
    if not db_mp.is_closed():
        db_mp.close()

    # Delete any database files on disk so we are starting fresh each time.
    filenames = ['/tmp/bench-tc.db', '/tmp/bench-mp.db']
    for filename in filenames:
        if os.path.exists(filename):
            os.unlink(filename)

    # Re-create the tables.
    CategoryTC.create_table()
    CategoryClosure.create_table()
    CategoryMP.create_table()

    with timed() as timer:
        with count_queries() as cq:
            create_tree(model_class, structure)

    result = Result(structure, timer.time, cq.count)
    names = make_names(structure)

    for name in names:
        with timed() as timer:
            query(model_class, names)
        result.add_result(name, timer.time)

    return result
In [25]:
structures = [
    ('Deep', [3, 3, 3, 2, 2, 2, 2, 2, 2]),
    ('Very deep', [2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1]),
    ('Uniform', [5, 5, 5, 5, 3]),
    ('Top-heavy', [50, 5, 5, 1]),
    ('Bottom-heavy', [1, 1, 5, 10, 50]),
    ('Very wide', [1000, 1]),
]
In [26]:
# Disable logging while running the benchmark.
import logging
logging.disable(logging.DEBUG)
In [27]:
# Run the benchmark!
for label, structure in structures:
    print label    
    for model_class in [CategoryMP, CategoryTC]:
        result = benchmark(model_class, structure)
        print '%s: %r' % (model_class.__name__, result)
Deep
CategoryMP: <Result: 3441 nodes, created in 1.186, avg 0.498>
CategoryTC: <Result: 3441 nodes, created in 0.999, avg 0.440>
Very deep
CategoryMP: <Result: 2046 nodes, created in 0.672, avg 0.371>
CategoryTC: <Result: 2046 nodes, created in 0.589, avg 0.338>
Uniform
CategoryMP: <Result: 2655 nodes, created in 0.854, avg 0.334>
CategoryTC: <Result: 2655 nodes, created in 0.771, avg 0.315>
Top-heavy
CategoryMP: <Result: 2800 nodes, created in 0.917, avg 0.168>
CategoryTC: <Result: 2800 nodes, created in 0.812, avg 0.131>
Bottom-heavy
CategoryMP: <Result: 2557 nodes, created in 0.834, avg 0.703>
CategoryTC: <Result: 2557 nodes, created in 0.742, avg 0.653>
Very wide
CategoryMP: <Result: 2000 nodes, created in 0.646, avg 0.060>
CategoryTC: <Result: 2000 nodes, created in 0.611, avg 0.048>