Mixing Bloom Filters, MinHashes and SBT

  • Nodegraph can work as a regular Bloom Filter if you ignore the k-size and call count directly with integers. (This is limited to 64-bit ints, since that's what the implementation supports.) That is enough for the values in the signatures created with Sourmash, so that's what I'm using for now.

  • Initial SBT implementation comes from https://github.com/dib-lab/2015-09-10-scihack, which was based on https://github.com/ctb/2015-sbt-demo. The scihack version adds better tree balancing (all leaves are always in the last or second-to-last level), saving and loading from disk, and methods for printing the tree (for easier debugging). I extended it by supporting different types of Leaf classes and moving from a binary to an n-ary tree (which might work better with lots of Leaves).

  • I'm using the same signatures from the sourmash urchin demo.
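
To make the first point concrete, here is a minimal sketch of a Bloom filter that stores integers directly. It is not the Nodegraph implementation: the class name IntBloomFilter, the table sizes and the modulo-based hashing are assumptions I picked for illustration.

```python
class IntBloomFilter:
    """Toy Bloom filter over unsigned 64-bit integers (illustrative only)."""

    def __init__(self, table_sizes=(1009, 1013, 1019, 1021)):
        # One table per "hash function"; each table has a distinct prime size.
        self.tables = [bytearray(size) for size in table_sizes]

    def add(self, value):
        if not 0 <= value < 2 ** 64:
            raise ValueError("only unsigned 64-bit integers are supported")
        for table in self.tables:
            table[value % len(table)] = 1

    def get(self, value):
        # True only if the slot is set in every table.
        return all(table[value % len(table)] for table in self.tables)


bf = IntBloomFilter()
for h in (12345678901234567890, 42, 2 ** 40 + 7):
    bf.add(h)
```

A lookup never misses a stored value, but an absent value can be reported as present when all of its slots were set by other values.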

In [1]:
cd -q ..
In [2]:
from collections import defaultdict
from glob import glob
import os
from functools import partial

from IPython.display import Image

from sourmash_lib import signature
from sbt import SBT, GraphFactory, Node
from sbtmh import search_minhashes, SigLeaf
In [3]:
factory = GraphFactory(31, 1e5, 4)

SBT methods

We'll use a reduced dataset to demonstrate the SBT methods (only lividus-* signatures, and a binary tree).

In [4]:
sig_to_search = "urchin/lividus-SRR1735497.sig"
with open(sig_to_search, 'r') as data:
    to_search = signature.load_signatures(data)[0]

We build trees by adding leaves with the add_node method, which takes care of positioning the leaf in the tree and updating internal nodes.

In [5]:
tree = SBT(factory)
for f in glob("urchin/lividus*.sig"):
    with open(f, 'r') as data:
        sig = signature.load_signatures(data)
    leaf = SigLeaf(os.path.basename(f), sig[0])
    tree.add_node(leaf)
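
As a rough picture of what "positioning the leaf and updating internal nodes" can mean, here is a toy binary tree over plain Python sets. ToySBT and add_leaf are hypothetical names, and the breadth-first numbering of positions is an assumption of this sketch, not a description of the SBT class in sbt.

```python
class ToySBT:
    """Toy binary tree of sets; positions are breadth-first, root is 0."""

    def __init__(self):
        self.internal = {}  # position -> union of all values stored below it
        self.leaves = {}    # position -> (name, set of values)

    def add_leaf(self, name, values):
        values = set(values)
        if not self.internal:
            self.internal[0] = set()  # the root is always an internal node
        # Positions 0..n-1 are always occupied, so n is the first free one.
        pos = len(self.internal) + len(self.leaves)
        parent = (pos - 1) // 2
        if parent in self.leaves:
            # The parent slot holds a leaf: turn that slot into an internal
            # node and move the old leaf down, next to the new one.
            old = self.leaves.pop(parent)
            self.internal[parent] = set(old[1])
            self.leaves[pos] = old
            pos += 1
        self.leaves[pos] = (name, values)
        # Walk up to the root so every ancestor also covers the new leaf.
        while pos > 0:
            pos = (pos - 1) // 2
            self.internal[pos] |= values


toy = ToySBT()
for i, name in enumerate("ABCDE"):
    toy.add_leaf(name, {i, i + 10})
```

With five leaves, the leaves end up at positions 4 to 8, which are the last two levels of the tree, and the root holds the union of all five sets.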

Printing

There are two printing methods:

  • print, a simple ASCII tree
  • print_dot, which can be fed into GraphViz
In [6]:
tree.print()
 *Node:internal.0 [occupied: 1587, fpr: 6.4e-08]
     *Node:internal.2 [occupied: 1020, fpr: 1.1e-08]
         **Leaf:lividus-SRR1735498.sig -> lividus-SRR1735498.sig
         *Node:internal.5 [occupied: 787, fpr: 3.8e-09]
             **Leaf:lividus-SRR1735497.sig -> lividus-SRR1735497.sig
             **Leaf:lividus-SRR1735496.sig -> lividus-SRR1735496.sig
     *Node:internal.1 [occupied: 1207, fpr: 2.1e-08]
         *Node:internal.4 [occupied: 750, fpr: 3.2e-09]
             **Leaf:lividus-SRR1735499.sig -> lividus-SRR1735499.sig
             **Leaf:lividus-SRR1735501.sig -> lividus-SRR1735501.sig
         *Node:internal.3 [occupied: 848, fpr: 5.2e-09]
             **Leaf:lividus-SRR1735500.sig -> lividus-SRR1735500.sig
             **Leaf:lividus-SRR1664663.sig -> lividus-SRR1664663.sig
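
The occupied and fpr fields in the output above seem to be related by simple arithmetic. Assuming the arguments in GraphFactory(31, 1e5, 4) are the k-size, the approximate table size and the number of tables, and assuming the false positive rate is estimated as the occupied fraction raised to the number of tables, the printed values can be roughly reproduced. approx_fpr is a hypothetical helper, not part of sbt or khmer.

```python
def approx_fpr(occupied, table_size=1e5, n_tables=4):
    """Estimate the false positive rate from the number of occupied slots."""
    return (occupied / table_size) ** n_tables


# Occupancy values taken from the first three internal nodes printed above.
approx = {occupied: approx_fpr(occupied) for occupied in (1587, 1020, 787)}
for occupied, fpr in approx.items():
    print("occupied: {}, approx fpr: {:.1e}".format(occupied, fpr))
```

The estimates land close to the printed fpr values; small differences are expected if the real table sizes are not exactly 1e5.
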
In [7]:
%%capture dag
tree.print_dot()
In [8]:
with open('dag.dot', 'w') as f:
    f.write(dag.stdout)
!dot -Tpng -Nshape=ellipse dag.dot > tree.png
Image("tree.png")
Out[8]:

Searching

The find method needs a search function. In our case it is the search_minhashes function, defined in sbtmh:

In [9]:
from inspect import getsource
print(getsource(search_minhashes))
def search_minhashes(node, sig, threshold, results=None):
    mins = sig.estimator.mh.get_mins()

    if isinstance(node, SigLeaf):
        matches = node.data.estimator.count_common(sig.estimator)
    else:  # Node or Leaf, Nodegraph by minhash comparison
        matches = sum(1 for value in mins if node.data.get(value))

    if results is not None:
        results[node.name] = matches / len(mins)

    if matches / len(mins) >= threshold:
        return 1
    return 0

There are two cases to consider: is the node a SigLeaf (another MinHash) or a Nodegraph (an internal node)? Both do the same thing (count how many values are in the intersection), but need to use the appropriate method from each class.

results can be passed to keep track of intermediary results (see which nodes were searched), but more about this later.
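
Here is a sketch of how a search function can drive the traversal, assuming that find walks down from the root and only descends into nodes for which the function returns a true value. toy_search, toy_find and the dictionary-based tree are hypothetical stand-ins (plain Python sets instead of MinHashes and Nodegraphs), not the code in sbt.

```python
def toy_search(node_values, query, threshold):
    """Return 1 if enough of the query's values are present in the node."""
    found = sum(1 for value in query if value in node_values)
    return 1 if found / len(query) >= threshold else 0


def toy_find(nodes, name, query, threshold, visited):
    """Depth-first search that skips subtrees failing the search function."""
    values, children = nodes[name]
    visited.append(name)
    if not toy_search(values, query, threshold):
        return []
    if not children:  # a leaf that passed the threshold is a match
        return [name]
    hits = []
    for child in children:
        hits.extend(toy_find(nodes, child, query, threshold, visited))
    return hits


# name -> (values, children); internal nodes hold the union of their leaves.
toy_nodes = {
    "root": ({1, 2, 3, 4, 5, 6, 7, 8}, ["left", "right"]),
    "left": ({1, 2, 3, 4}, ["a", "b"]),
    "right": ({5, 6, 7, 8}, ["c", "d"]),
    "a": ({1, 2}, []),
    "b": ({3, 4}, []),
    "c": ({5, 6}, []),
    "d": ({7, 8}, []),
}
visited = []
hits = toy_find(toy_nodes, "root", {1, 2, 3}, 0.5, visited)
```

In this toy run only leaf "a" passes the threshold, and the leaves under "right" are never visited because their parent already fails the test. That pruning is what makes the tree useful for search.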

Finally, we can pass search_minhashes to the find method:

In [10]:
print('*' * 60)
print("{}:".format(sig_to_search))

filtered = tree.find(search_minhashes, to_search, 0.1)
matches = [(str(s.metadata), s.data.similarity(to_search))
            for s in filtered]

print(*matches, sep='\n')
************************************************************
urchin/lividus-SRR1735497.sig:
('lividus-SRR1735498.sig', 0.47999998927116394)
('lividus-SRR1664663.sig', 0.41600000858306885)
('lividus-SRR1735500.sig', 0.4059999883174896)
('lividus-SRR1735501.sig', 0.3619999885559082)
('lividus-SRR1735499.sig', 0.4580000042915344)
('lividus-SRR1735496.sig', 0.421999990940094)
('lividus-SRR1735497.sig', 1.0)

Saving and loading

In [11]:
tree.save('urchin')
Out[11]:
'urchin.sbt.json'
In [12]:
tree = SBT.load('urchin.sbt.json', leaf_loader=SigLeaf.load)
In [13]:
print('*' * 60)
print("{}:".format(sig_to_search))

load_filtered = tree.find(search_minhashes, to_search, 0.1)
load_matches = [(str(s.metadata), s.data.similarity(to_search))
                 for s in load_filtered]

print(*load_matches, sep='\n')
************************************************************
urchin/lividus-SRR1735497.sig:
('lividus-SRR1735498.sig', 0.47999998927116394)
('lividus-SRR1664663.sig', 0.41600000858306885)
('lividus-SRR1735500.sig', 0.4059999883174896)
('lividus-SRR1735501.sig', 0.3619999885559082)
('lividus-SRR1735499.sig', 0.4580000042915344)
('lividus-SRR1735496.sig', 0.421999990940094)
('lividus-SRR1735497.sig', 1.0)

And we can see results before saving the tree and after loading it are the same:

In [14]:
set(matches) == set(load_matches)
Out[14]:
True

n-ary

We'll use the urchin/purpuratus* signatures: there are more of them than urchin/lividus*, so differences in tree layout are more pronounced, while the trees stay small enough to visualize without huge images (which is what happens when using all the urchin signatures).

In [15]:
sig_to_search = "urchin/purpuratus-SRR1012313.sig"
with open(sig_to_search, 'r') as data:
    to_search = signature.load_signatures(data)[0]
In [16]:
trees = {}
for d in (2, 5, 10):
    trees[d] = SBT(factory, d=d)
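
To get a feeling for what d changes, this small calculation shows how many levels a d-ary tree needs below the root to hold a given number of leaves on one level. levels_needed is a hypothetical helper, and 100 leaves is an arbitrary number chosen for illustration, not the size of the urchin dataset.

```python
def levels_needed(n_leaves, arity):
    """Smallest depth whose level has at least n_leaves slots."""
    depth, slots = 0, 1
    while slots < n_leaves:
        slots *= arity  # each extra level multiplies the slots by the arity
        depth += 1
    return depth


depths = {arity: levels_needed(100, arity) for arity in (2, 5, 10)}
# depths is {2: 7, 5: 3, 10: 2}
```

A higher arity gives a shallower tree, so a search passes through fewer internal nodes on the way down, but each internal node then has more children to test.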

We can read all signatures once and add them to each tree (instead of re-reading the signatures every time we build a tree):

In [17]:
for f in glob("urchin/purpuratus*.sig"):
    with open(f, 'r') as data:
        sig = signature.load_signatures(data)
    leaf = SigLeaf(os.path.basename(f), sig[0])
    for d in (2, 5, 10):
        trees[d].add_node(leaf)

2-ary

In [18]:
%%capture dag
trees[2].print_dot()
In [19]:
with open('dag.dot', 'w') as f:
    f.write(dag.stdout)
!twopi -Tpng -Granksep=5 dag.dot > tree.png
Image("tree.png")
Out[19]: