Radix sort, in both least-significant-digit (LSD) and most-significant-digit (MSD) versions. The examples use the DNA alphabet.
We start with a function that performs a single pass.
from collections import defaultdict
def radix_pass(strs, ordr, depth, alphabet='$ACGTn'):
    """Perform one stable bucket pass of radix sort on a single character.

    Args:
        strs: Indexable collection of strings. Every string referenced by
            ``ordr`` must have a character at offset ``depth``.
        ordr: Iterable of indices into ``strs`` giving the current order.
        depth: Character offset examined in this pass.
        alphabet: Distinct characters in ascending sort order. Defaults to
            the DNA alphabet ``'$ACGTn'`` used by the examples.

    Returns:
        A new list containing the indices from ``ordr``, grouped by the
        character at ``depth`` in ``alphabet`` order. Indices that share a
        character keep their relative order from ``ordr`` (stable).

    Raises:
        ValueError: If a referenced string has a character at ``depth``
            that is not in ``alphabet``. Raising avoids silently dropping
            the index, which would make the result no longer a permutation
            of ``ordr``.
    """
    buckets = defaultdict(list)
    for i in ordr:
        buckets[strs[i][depth]].append(i)
    unknown = set(buckets).difference(alphabet)
    if unknown:
        raise ValueError(
            'characters {!r} at depth {} are not in alphabet {!r}'.format(
                sorted(unknown), depth, alphabet))
    # Visit buckets in alphabet order; .get() avoids creating empty buckets.
    return [i for c in alphabet for i in buckets.get(c, ())]
# Example 1: the input is already ordered by its only character, so the
# single pass returns the indices unchanged.
strs = ['A', 'A', 'C', 'G', 'G', 'G', 'n']
radix_pass(strs, range(len(strs)), 0)
[0, 1, 2, 3, 4, 5, 6]
# Example 2: unordered input. Indices come back grouped by character
# (A, C, G, n); within a group they keep their input order.
strs = ['A', 'G', 'A', 'G', 'C', 'G', 'n']
radix_pass(strs, range(len(strs)), 0)
[0, 2, 4, 1, 3, 5, 6]
Chain two radix_pass calls together to get the overall sorted order.
# First call: stable pass on the last character (depth 1), starting from
# the original order.
strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']
lsd1 = radix_pass(strs, range(len(strs)), 1)
lsd1
[2, 3, 4, 8, 0, 1, 5, 6, 7]
# Second call, using result from first: pass on the first character
# (depth 0). Because each pass is stable, ties on the first character stay
# ordered by the second character.
radix_pass(strs, lsd1, 0)
[2, 0, 1, 3, 5, 4, 6, 8, 7]
To completely sort the strings, radix_lsd performs all the passes, one per character position, from the last position to the first.
def radix_lsd(strs):
    """Least-significant-digit (LSD) radix sort on same-length strings.

    Runs one stable ``radix_pass`` per character position, from the last
    position to the first, feeding each pass the order produced by the
    previous one.

    Args:
        strs: Sequence of strings, expected to all have the same length.
            Only the length of ``strs[0]`` is consulted to decide how many
            passes to run.

    Returns:
        A list of indices into ``strs`` that puts the strings in
        stable-sorted order. An empty input yields an empty list.
    """
    if len(strs) == 0:
        # Nothing to sort; also avoids indexing strs[0] below.
        return []
    # Start from a real list so the return type is a list even when the
    # strings are empty and no pass runs.
    ordr = list(range(len(strs)))
    for depth in reversed(range(len(strs[0]))):
        ordr = radix_pass(strs, ordr, depth)
    return ordr
# Same strings as the chained example; radix_lsd yields the same order as
# running the two radix_pass calls by hand.
strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']
radix_lsd(strs)
[2, 0, 1, 3, 5, 4, 6, 8, 7]
MSD radix sort, implemented with a single recursive helper function.
def radix_msd_helper(strs, ordr, depth, alphabet='$ACGTn'):
    """Recursive step of most-significant-digit (MSD) radix sort.

    Buckets the indices in ``ordr`` by the character at ``depth``, then
    sorts each bucket recursively on the next character and concatenates
    the buckets in ``alphabet`` order.

    Args:
        strs: Sequence of strings, expected to all have the same length.
            Only the length of ``strs[0]`` is consulted to detect that the
            characters are exhausted.
        ordr: Sized iterable of indices into ``strs`` to be sorted.
        depth: Character offset examined at this level of the recursion.
        alphabet: Distinct characters in ascending sort order. Defaults to
            the DNA alphabet ``'$ACGTn'`` used by the examples.

    Returns:
        A new list containing the indices from ``ordr``, ordered by the
        characters of their strings from ``depth`` onward.

    Raises:
        ValueError: If a bucketed string has a character at ``depth`` that
            is not in ``alphabet``. Raising avoids silently dropping the
            index from the result.
    """
    if len(ordr) <= 1 or depth >= len(strs[0]):
        # Base cases: at most one element, or we've exhausted characters.
        # Copy into a list so callers never receive a range object.
        return list(ordr)
    buckets = defaultdict(list)
    for i in ordr:
        buckets[strs[i][depth]].append(i)
    unknown = set(buckets).difference(alphabet)
    if unknown:
        raise ValueError(
            'characters {!r} at depth {} are not in alphabet {!r}'.format(
                sorted(unknown), depth, alphabet))
    result = []
    for c in alphabet:
        if c in buckets:  # skip empty buckets without creating them
            result.extend(
                radix_msd_helper(strs, buckets[c], depth + 1, alphabet))
    return result
def radix_msd(strs):
    """Most-significant-digit (MSD) radix sort on same-length strings.

    Args:
        strs: Sequence of strings, expected to all have the same length.

    Returns:
        A list of indices into ``strs`` that puts the strings in sorted
        order, as computed by ``radix_msd_helper`` starting at depth 0.
    """
    # list() guarantees a list result even if the helper hands back the
    # initial range for trivial inputs (e.g. zero or one string).
    return list(radix_msd_helper(strs, range(len(strs)), 0))
# Same two-character strings as the LSD example; MSD produces the same
# permutation.
strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']
radix_msd(strs)
[2, 0, 1, 3, 5, 4, 6, 8, 7]
# Longer strings: check that the MSD and LSD versions agree, then show the
# resulting order.
strs = ['GATTACA', 'GATTAAA', 'GAATACA', 'GATAACA', 'AATTAAA']
assert radix_msd(strs) == radix_lsd(strs)
radix_msd(strs)
[4, 2, 3, 1, 0]