Examples of Radix sort

Both least-significant-digit (LSD) and most-significant-digit (MSD) versions are shown. The examples use the DNA alphabet plus the extra symbols $ and n, ordered $ < A < C < G < T < n.

Least significant digit

Start with a function that performs a single pass. It stably reorders string indices by the character at one position (the depth).

In [1]:
from collections import defaultdict

def radix_pass(strs, ordr, depth):
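    """ Given a collection of same-length strings, an ordering of their
        indices (ordr) and a depth, return a permutation of ordr that
        stably sorts the strings according to the character at that depth """
    buckets = defaultdict(list)
    for i in ordr:  # visiting indices in ordr order keeps the pass stable
        buckets[strs[i][depth]].append(i)
    # concatenate buckets in alphabet order: $ < A < C < G < T < n
    # (characters outside this alphabet would be silently dropped)
    return [x for sublist in [buckets[c] for c in '$ACGTn'] for x in sublist]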
In [2]:
strs = ['A', 'A', 'C', 'G', 'G', 'G', 'n']
radix_pass(strs, range(len(strs)), 0)
Out[2]:
[0, 1, 2, 3, 4, 5, 6]
In [3]:
strs = ['A', 'G', 'A', 'G', 'C', 'G', 'n']
radix_pass(strs, range(len(strs)), 0)
Out[3]:
[0, 2, 4, 1, 3, 5, 6]
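The defaultdict bucketing above is one way to implement a pass. A common alternative is key-indexed counting (counting sort): count how many strings have each character, turn the counts into starting offsets, then drop each index into its bucket's next free slot. The sketch below assumes the same '$ACGTn' alphabet; the name radix_pass_counting is hypothetical, and unlike radix_pass it raises KeyError on a character outside the alphabet instead of dropping it.

```python
ALPHABET = '$ACGTn'
RANK = {c: r for r, c in enumerate(ALPHABET)}

def radix_pass_counting(strs, ordr, depth):
    """Hypothetical counting-sort variant of radix_pass: stably reorders
    the indices in ordr by the character at position depth."""
    ordr = list(ordr)
    # Count how many indices fall into each character's bucket
    counts = [0] * len(ALPHABET)
    for i in ordr:
        counts[RANK[strs[i][depth]]] += 1
    # Exclusive prefix sums: where each bucket starts in the output
    offsets = []
    total = 0
    for cnt in counts:
        offsets.append(total)
        total += cnt
    # Scanning ordr left to right keeps equal keys in input order (stable)
    out = [None] * len(ordr)
    for i in ordr:
        r = RANK[strs[i][depth]]
        out[offsets[r]] = i
        offsets[r] += 1
    return out

strs = ['A', 'G', 'A', 'G', 'C', 'G', 'n']
print(radix_pass_counting(strs, range(len(strs)), 0))  # [0, 2, 4, 1, 3, 5, 6]
```

This gives the same permutation as Out[3]; it trades the dictionary of lists for two fixed-size arrays, which is the usual choice when the alphabet is small and known in advance.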

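Chain two radix_pass calls to get the overall sorted order: first sort on the last character (depth 1), then on the first character (depth 0). Because each pass is stable, strings that tie on the first character keep the order set by the previous pass, i.e. they stay sorted by their second character.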

In [4]:
# First call
strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']
lsd1 = radix_pass(strs, range(len(strs)), 1)
lsd1
Out[4]:
[2, 3, 4, 8, 0, 1, 5, 6, 7]
In [5]:
# Second call, using result from first
radix_pass(strs, lsd1, 0)
Out[5]:
[2, 0, 1, 3, 5, 4, 6, 8, 7]
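Mapping the permutation back to strings makes the result easier to read. The sketch below repeats radix_pass in compact form so that it runs on its own, and also runs the two passes in the opposite order (depth 0 first) to show why LSD must start from the last position: the final pass always becomes the primary sort key.

```python
from collections import defaultdict

def radix_pass(strs, ordr, depth):
    # Same stable bucketing pass as In [1]; alphabet order $ < A < C < G < T < n
    buckets = defaultdict(list)
    for i in ordr:
        buckets[strs[i][depth]].append(i)
    return [x for c in '$ACGTn' for x in buckets[c]]

strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']

# LSD order: last character (depth 1) first, then first character (depth 0)
lsd = radix_pass(strs, radix_pass(strs, range(len(strs)), 1), 0)
print([strs[i] for i in lsd])
# ['AA', 'AG', 'CG', 'GA', 'GT', 'TC', 'Tn', 'nC', 'nn']

# Passes in the wrong order: the final pass (depth 1) becomes the primary key
wrong = radix_pass(strs, radix_pass(strs, range(len(strs)), 0), 1)
print([strs[i] for i in wrong])
# ['AA', 'GA', 'TC', 'nC', 'AG', 'CG', 'GT', 'Tn', 'nn']
```

In the second result the strings are grouped by their last character, with ties broken by the first character, which is the reverse of the intended ordering.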

To sort strings of any fixed length, radix_lsd performs one pass per character position, from the last position back to the first.

In [6]:
def radix_lsd(strs):
    """ Least-significant-digit (LSD) radix sort on collection of
        same-length strings.  Returns a permutation that puts the list
        in stable-sorted order. """
    assert len(strs) > 0
    ordr = list(range(len(strs)))  # start from the identity permutation
    # one stable pass per position, last character first
    for i in range(len(strs[0])-1, -1, -1):
        ordr = radix_pass(strs, ordr, i)
    return ordr
In [7]:
strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']
radix_lsd(strs)
Out[7]:
[2, 0, 1, 3, 5, 4, 6, 8, 7]
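One way to gain confidence in radix_lsd is to compare it with Python's built-in sorted, which is also stable. Sorting indices by each string's sequence of alphabet ranks should therefore give exactly the same permutation. This sketch assumes random same-length strings over ACGTn with a fixed seed; reference_order is a hypothetical helper name. (Here '$ACGTn' also happens to be in ASCII order, so key=lambda i: strs[i] would work too, but the rank key does not depend on that.)

```python
import random
from collections import defaultdict

ALPHABET = '$ACGTn'
RANK = {c: r for r, c in enumerate(ALPHABET)}

def radix_pass(strs, ordr, depth):
    # Stable bucketing pass on the character at position depth
    buckets = defaultdict(list)
    for i in ordr:
        buckets[strs[i][depth]].append(i)
    return [x for c in ALPHABET for x in buckets[c]]

def radix_lsd(strs):
    # One pass per position, last character first
    ordr = list(range(len(strs)))
    for i in range(len(strs[0]) - 1, -1, -1):
        ordr = radix_pass(strs, ordr, i)
    return ordr

def reference_order(strs):
    # Stable sort of indices by the rank sequence of each string
    return sorted(range(len(strs)), key=lambda i: [RANK[c] for c in strs[i]])

random.seed(0)
for trial in range(100):
    length = random.randint(1, 6)
    strs = [''.join(random.choice('ACGTn') for _ in range(length))
            for _ in range(random.randint(1, 30))]
    assert radix_lsd(strs) == reference_order(strs)
print("radix_lsd agrees with sorted() on 100 random inputs")
```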

Most significant digit

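MSD radix sort works in the opposite direction: it buckets the strings by their first character, then recursively sorts each bucket on the next character. Here it is written as a single recursive helper plus a thin wrapper.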

In [8]:
def radix_msd_helper(strs, ordr, depth):
    """ Most-significant-digit (MSD) radix sort of the indices in ordr,
        starting at character position depth.  strs is a collection of
        same-length strings.  Returns a permutation that puts those
        strings in stable-sorted order. """
    if len(ordr) <= 1 or depth >= len(strs[0]):
        return ordr  # base cases: at most 1 elt in list, or we've exhausted characters
    buckets = defaultdict(list)
    for i in ordr:
        buckets[strs[i][depth]].append(i)
    # recursively sort each bucket on the next character, then concatenate
    # the sorted buckets in alphabet order
    return [x for c in '$ACGTn' for x in radix_msd_helper(strs, buckets[c], depth+1)]

def radix_msd(strs):
    """ MSD radix sort on collection of same-length strings. """
    return radix_msd_helper(strs, list(range(len(strs))), 0)
In [9]:
strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']
radix_msd(strs)
Out[9]:
[2, 0, 1, 3, 5, 4, 6, 8, 7]
In [10]:
strs = ['GATTACA', 'GATTAAA', 'GAATACA', 'GATAACA', 'AATTAAA']
assert radix_msd(strs) == radix_lsd(strs)
radix_msd(strs)
Out[10]:
[4, 2, 3, 1, 0]
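A practical difference between the two versions is how many characters they read. LSD always reads every character (5 strings × 7 positions = 35 reads here), whereas MSD stops recursing once a bucket holds at most one string. The sketch below is a hypothetical instrumented copy of radix_msd_helper (the name radix_msd_counted and the read counter are illustrative only) that counts character reads.

```python
from collections import defaultdict

ALPHABET = '$ACGTn'

def radix_msd_counted(strs):
    """Hypothetical instrumented MSD radix sort.
    Returns (permutation, number of characters read)."""
    reads = 0

    def helper(ordr, depth):
        nonlocal reads
        # Base cases: at most 1 index left, or all characters consumed
        if len(ordr) <= 1 or depth >= len(strs[0]):
            return ordr
        buckets = defaultdict(list)
        for i in ordr:
            reads += 1  # one character inspected
            buckets[strs[i][depth]].append(i)
        return [x for c in ALPHABET for x in helper(buckets[c], depth + 1)]

    return helper(list(range(len(strs))), 0), reads

strs = ['GATTACA', 'GATTAAA', 'GAATACA', 'GATAACA', 'AATTAAA']
perm, reads = radix_msd_counted(strs)
print(perm, reads)  # [4, 2, 3, 1, 0] 20
```

MSD reads 20 characters here instead of 35, because 'AATTAAA' is isolated after the first character and the others separate gradually. The saving depends on the input: with many identical or long-shared-prefix strings, MSD still has to read almost everything.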