#!/usr/bin/env python
# coding: utf-8

# ### Examples of Radix sort
#
# Both least-significant-digit and most-significant-digit versions.
# Examples use the DNA alphabet.

from collections import defaultdict

# Characters the sorts accept, listed in the order in which they sort.
DNA_ALPHABET = '$ACGTn'


def _bucket_by_char(strs, ordr, depth, alphabet):
    """Group the indices in ``ordr`` by the character ``strs[i][depth]``.

    Indices keep their relative input order within each bucket, which is what
    makes the radix passes stable.  Returns a ``defaultdict(list)`` mapping
    character -> list of indices.

    Raises ValueError if a character is not in ``alphabet``: without this
    check such an index would land in a bucket that is never emitted and
    silently vanish from the result.
    """
    allowed = set(alphabet)
    buckets = defaultdict(list)
    for i in ordr:
        c = strs[i][depth]
        if c not in allowed:
            raise ValueError(
                'character {!r} at position {} of string {} is not in '
                'alphabet {!r}'.format(c, depth, i, alphabet))
        buckets[c].append(i)
    return buckets


def _check_strings(strs, alphabet):
    """Validate non-empty ``strs`` and return the length they all share.

    Raises ValueError if the strings differ in length or use a character
    outside ``alphabet``.
    """
    width = len(strs[0])
    allowed = set(alphabet)
    for i, s in enumerate(strs):
        if len(s) != width:
            raise ValueError(
                'radix sort expects same-length strings: string 0 has '
                'length {} but string {} has length {}'.format(
                    width, i, len(s)))
        extra = set(s).difference(allowed)
        if extra:
            raise ValueError(
                'string {} contains characters not in alphabet {!r}: '
                '{!r}'.format(i, alphabet, ''.join(sorted(extra))))
    return width


# #### Least significant digit
# Start with a function that conducts a single pass.

def radix_pass(strs, ordr, depth, alphabet=DNA_ALPHABET):
    """Stably reorder the indices ``ordr`` by the character at ``depth``.

    ``strs`` is a collection of same-length strings; only ``strs[i]`` for
    ``i`` in ``ordr`` is read.  ``alphabet`` lists the allowed characters in
    sort order (default ``'$ACGTn'``).

    Returns a list holding the indices of ``ordr`` arranged so that
    ``strs[i][depth]`` follows ``alphabet`` order; indices with equal
    characters keep their order from ``ordr``.

    Raises ValueError if some ``strs[i][depth]`` is not in ``alphabet``.
    """
    buckets = _bucket_by_char(strs, ordr, depth, alphabet)
    # Concatenate the buckets in alphabet order.
    return [i for c in alphabet for i in buckets[c]]


# Chaining radix_pass calls from the last character position to the first
# gives the overall sorted order; radix_lsd does all the passes.

def radix_lsd(strs, alphabet=DNA_ALPHABET):
    """Least-significant-digit (LSD) radix sort on same-length strings.

    Returns a list of indices that puts ``strs`` in stable-sorted order
    (an empty list for empty input).

    Raises ValueError if the strings differ in length or contain a
    character not in ``alphabet``.
    """
    # Always a list (never a range object), even if no pass runs.
    ordr = list(range(len(strs)))
    if not ordr:
        return ordr
    width = _check_strings(strs, alphabet)
    # Stable passes from the last position to the first, so earlier
    # positions take precedence in the final order.
    for depth in range(width - 1, -1, -1):
        ordr = radix_pass(strs, ordr, depth, alphabet)
    return ordr


# #### Most significant digit
# MSD radix sort with a single recursive function.

def radix_msd_helper(strs, ordr, depth, alphabet=DNA_ALPHABET):
    """Most-significant-digit (MSD) radix sort of the indices ``ordr``.

    Sorts the same-length strings ``strs[i]`` (for ``i`` in ``ordr``) on
    character positions ``depth`` onward.  Returns a list of those indices
    in stable-sorted order.

    Raises ValueError on a character not in ``alphabet``.
    """
    # Base cases: at most one index in the list, or all characters consumed.
    if len(ordr) <= 1 or depth >= len(strs[0]):
        return list(ordr)
    buckets = _bucket_by_char(strs, ordr, depth, alphabet)
    # Sort each bucket on the next position, then concatenate the buckets
    # in alphabet order.
    return [i for c in alphabet
            for i in radix_msd_helper(strs, buckets[c], depth + 1, alphabet)]


def radix_msd(strs, alphabet=DNA_ALPHABET):
    """Most-significant-digit (MSD) radix sort on same-length strings.

    Returns a list of indices that puts ``strs`` in stable-sorted order
    (an empty list for empty input).

    Raises ValueError if the strings differ in length or contain a
    character not in ``alphabet``.
    """
    # Validate up front: the recursion stops early on singleton buckets and
    # would otherwise never inspect some characters.
    if len(strs) > 0:
        _check_strings(strs, alphabet)
    return radix_msd_helper(strs, list(range(len(strs))), 0, alphabet)


if __name__ == '__main__':
    # Single passes over one-character strings.
    strs = ['A', 'A', 'C', 'G', 'G', 'G', 'n']
    print(radix_pass(strs, range(len(strs)), 0))

    strs = ['A', 'G', 'A', 'G', 'C', 'G', 'n']
    print(radix_pass(strs, range(len(strs)), 0))

    # Chain two radix_pass calls together to get the overall sorted order.
    strs = ['AG', 'CG', 'AA', 'GA', 'TC', 'GT', 'Tn', 'nn', 'nC']
    lsd1 = radix_pass(strs, range(len(strs)), 1)  # first call
    print(lsd1)
    print(radix_pass(strs, lsd1, 0))  # second call, using result from first

    # To completely sort the strings, radix_lsd does all the passes.
    print(radix_lsd(strs))

    # MSD radix sort gives the same order.
    print(radix_msd(strs))

    strs = ['GATTACA', 'GATTAAA', 'GAATACA', 'GATAACA', 'AATTAAA']
    assert radix_msd(strs) == radix_lsd(strs)
    print(radix_msd(strs))