#!/usr/bin/env python
# coding: utf-8

# # Computation and comparison of the bispectrum and the rotational bispectrum

# We show how to compute the bispectrum and the rotational bispectrum, as presented in the paper
#
# - _Image processing in the semidiscrete group of rototranslations_ by D. Prandi, U. Boscain and J.-P. Gauthier.

# In[1]:


import numpy as np
from numpy import fft
from numpy import linalg as LA
from scipy import ndimage
from scipy import signal
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import os
get_ipython().run_line_magic('matplotlib', 'inline')


# ## Auxiliary functions

# In[2]:


def int2intvec(a):
    """
    Auxiliary function to recover a vector with the decimal digits of a
    given integer, least-significant digit first (e.g. 123 -> [3, 2, 1]).

    `a` : non-negative integer

    Returns a 1-D numpy array of dtype int.
    Raises ValueError if `a` is negative.
    """
    a = int(a)
    if a < 0:
        # The digit-peeling loop below never reaches 0 for negative input.
        raise ValueError("int2intvec expects a non-negative integer")
    digits = []
    while True:
        # divmod on ints stays exact; true division ('/') would turn the
        # quotient into a float on Python 3.
        a, digit = divmod(a, 10)
        digits.append(digit)
        if a == 0:
            break
    return np.array(digits, dtype=int)


# In[3]:


ALPHABET7 = "0123456"
ALPHABET10 = "0123456789"


def base_encode(num, alphabet):
    """
    Encode a number in Base X, where X is the length of `alphabet`.

    `num`: The non-negative integer to encode
    `alphabet`: string whose i-th character is the symbol used for digit i

    The encoded symbols are joined and converted with int(), so the
    alphabet has to be made of decimal digit characters (as ALPHABET7 and
    ALPHABET10 are).
    Raises ValueError if `num` is negative.
    """
    if str(num) == alphabet[0]:
        return int(0)
    if num < 0:
        # Floor division of a negative number never reaches 0.
        raise ValueError("base_encode expects a non-negative integer")
    arr = []
    base = len(alphabet)
    while num:
        num, rem = divmod(num, base)
        arr.append(alphabet[rem])
    arr.reverse()
    return int(''.join(arr))


def base7to10(num):
    """
    Convert a number from base 7 to base 10

    `num`: The number to convert; its decimal digits are read as base-7
           digits (e.g. 16 -> 1*7 + 6 = 13)
    """
    return int(sum(int(digit) * 7**power
                   for power, digit in enumerate(int2intvec(num))))


def base10to7(num):
    """
    Convert a number from base 10 to base 7

    `num`: The number to convert; the result is an int whose decimal
           digits are the base-7 digits (e.g. 13 -> 16)
    """
    return base_encode(num, ALPHABET7)


# In[4]:


def rgb2gray(rgb):
    """
    Convert an image from RGB to grayscale

    `rgb`: The image to convert; channels 0, 1 and 2 of the last axis are
           used as red, green and blue
    """
    r, g, b = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b

    return gray


# In[5]:


def oversampling(image, factor = 7):
    """
    Oversample a grayscale image by a certain factor, dividing each pixel
    in factor*factor subpixels with the same intensity.

    `image`: The image to oversample (numpy array)
    `factor`: The oversampling factor

    Returns a new array of the same dtype whose first two axes are
    `factor` times larger.
    """
    # Repeating every row and then every column `factor` times yields the
    # same blocks as filling factor x factor tiles pixel by pixel.
    return np.repeat(np.repeat(image, factor, axis=0), factor, axis=1)


# ## Spiral architecture implementation

# Spiral architecture has been introduced by Sheridan in
#
# - [_Spiral Architecture for Machine Vision_](https://opus.lib.uts.edu.au/research/handle/2100/280), PhD thesis
# - [_Pseudo-invariant image transformations on a hexagonal lattice_](http://www.sciencedirect.com/science/article/pii/S0262885600000366), P. Sheridan, T. Hintz, and D. Alexander, Image Vis. Comput. 18, 907 (2000).
#
# The implementation with hyperpels that we use in the following is presented in
#
# - [_A New Simulation of Spiral Architecture_](http://ww1.ucmss.com/books/LFS/CSREA2006/IPC8173.pdf), X. He, T. Hintz, Q. Wu, H. Wang, and W. Jia, Proceedings of International Conference on Image Processing, Computer Vision, and Pattern Recognition (2006).
# - [_Hexagonal structure for intelligent vision_](http://ieeexplore.ieee.org/xpl/login.jsp?tp=&arnumber=1598543&url=http%3A%2F%2Fieeexplore.ieee.org%2Fiel5%2F10652%2F33619%2F01598543), X. He and W. Jia, in Proc. 1st Int. Conf. Inf. Commun. Technol. ICICT 2005 (2005), pp. 52–64.
#
# For a more detailed implementation, see the notebook [Hexagonal grid](http://nbviewer.ipython.org/github/dprn/dprn.github.io/blob/master/docs/notebooks/Hexagonal%20Grid.ipynb).

# We start by defining the centered hyperpel, which is defined on a 9x9 grid and is composed of 56 pixels.
# It has the shape
#
#     o o x x x x x o o
#     o x x x x x x x o
#     o x x x x x x x o
#     x x x x x x x x x
#     x x x C x x x x x
#     o x x x x x x x o
#     o x x x x x x x o
#     o o x x x x x o o

# In[6]:


# The centered hyperpel
hyperpel = np.array([
    [-1, 4], [0, 4], [1, 4], [2, 4], [3, 4],
    [-2, 3], [-1, 3], [0, 3], [1, 3], [2, 3], [3, 3], [4, 3],
    [-2, 2], [-1, 2], [0, 2], [1, 2], [2, 2], [3, 2], [4, 2],
    [-3, 1], [-2, 1], [-1, 1], [0, 1], [1, 1], [2, 1], [3, 1], [4, 1], [5, 1],
    [-3, 0], [-2, 0], [-1, 0], [0, 0], [1, 0], [2, 0], [3, 0], [4, 0], [5, 0],
    [-2, -1], [-1, -1], [0, -1], [1, -1], [2, -1], [3, -1], [4, -1],
    [-2, -2], [-1, -2], [0, -2], [1, -2], [2, -2], [3, -2], [4, -2],
    [-1, -3], [0, -3], [1, -3], [2, -3], [3, -3],
])
# Same offsets, shifted by (-1, -1); this is the variant added to the
# hyperpel centre when sampling.
hyperpel_sa = hyperpel - np.array([1, 1])


# We now compute, in `sa2hex`, the address of the center of the hyperpel corresponding to a certain spiral address.

# In[7]:


def sa2hex(spiral_address):
    """
    Compute the position of the centre of the hyperpel with the given
    spiral address.

    `spiral_address`: integer whose decimal digits are all in 0..6

    Returns a length-2 integer numpy array, or None (after printing a
    message) if the address contains anything but the digits 0..6.
    """
    # Split the number in basic units and call the auxiliary function.
    # Here we reverse the order, so that the index corresponds to the
    # decimal position.
    digits = str(spiral_address)[::-1]

    # Validate up front: this also rejects a minus sign, which would
    # otherwise make int() raise instead of reporting the address.
    if any(ch not in "0123456" for ch in digits):
        print("Invalid spiral address!")
        return None

    hex_address = np.array([0, 0])
    for position, ch in enumerate(digits):
        if ch != '0':
            hex_address += sa2hex_aux(int(ch), position)

    return hex_address


# Positions of the single-digit spiral addresses 0..6 (index = address).
_SA_UNIT_OFFSETS = (
    (0, 0), (0, 8), (-7, 4), (-7, -4), (0, -8), (7, -4), (7, 4),
)


# This computes the row/column positions of the base cases,
# that is, in the form a*10^(zeros).
def sa2hex_aux(a, zeros):
    """
    Position of the spiral address `a` followed by `zeros` zeros.

    `a`: a digit in 0..6
    `zeros`: non-negative number of trailing zeros

    Raises ValueError for a digit outside 0..6 or a negative `zeros`
    (both would otherwise recurse without ever reaching a base case).
    """
    if zeros < 0:
        raise ValueError("zeros must be non-negative")
    if a not in (0, 1, 2, 3, 4, 5, 6):
        raise ValueError("spiral digit must be in 0..6")

    # Base cases
    if zeros == 0:
        # A fresh array on every call, so callers can modify the result.
        return np.array(_SA_UNIT_OFFSETS[int(a)])

    # One level up: same digit one level down plus twice its successor
    # (6 wraps around to 1).
    return sa2hex_aux(a, zeros - 1) + 2 * sa2hex_aux(a % 6 + 1, zeros - 1)


# Then, we compute the value of the hyperpel corresponding to the spiral address, by averaging the values on the subpixels.
# In[8]:


def sa_value(oversampled_image, spiral_address):
    """
    Computes the value of the hyperpel corresponding to the given
    spiral coordinate.

    `oversampled_image`: 2-D numpy array to sample
    `spiral_address`: spiral address of the hyperpel

    Returns the average of the array over the subpixels of the hyperpel.
    """
    hp = hyperpel_sa + sa2hex(spiral_address)
    # Index all subpixels at once. Negative indices wrap around, as with
    # element-by-element indexing.
    return oversampled_image[hp[:, 0], hp[:, 1]].mean()


# ### Spiral addition and multiplication

# In[10]:


# Sum of two single-digit spiral addresses: _SPIRAL_ADDITION_TABLE[a][b].
_SPIRAL_ADDITION_TABLE = [
    [0, 1, 2, 3, 4, 5, 6],
    [1, 63, 15, 2, 0, 6, 64],
    [2, 15, 14, 26, 3, 0, 1],
    [3, 2, 26, 25, 31, 4, 0],
    [4, 0, 3, 31, 36, 42, 5],
    [5, 6, 0, 4, 42, 41, 53],
    [6, 64, 1, 0, 5, 53, 52],
]

# Product of two single-digit spiral addresses.
_SPIRAL_MULTIPLICATION_TABLE = [
    [0, 0, 0, 0, 0, 0, 0],
    [0, 1, 2, 3, 4, 5, 6],
    [0, 2, 3, 4, 5, 6, 1],
    [0, 3, 4, 5, 6, 1, 2],
    [0, 4, 5, 6, 1, 2, 3],
    [0, 5, 6, 1, 2, 3, 4],
    [0, 6, 1, 2, 3, 4, 5],
]


def _spiral_digits(address):
    """
    Digits of a spiral address (least-significant first) as a list of
    Python ints, or None if the address is negative or has a digit
    outside 0..6 (the range covered by the tables).
    """
    address = int(address)
    if address < 0:
        return None
    digits = [int(d) for d in int2intvec(address)]
    if any(d > 6 for d in digits):
        return None
    return digits


def spiral_add(a, b, mod=0):
    """
    Spiral addition of two spiral addresses.

    `a`, `b`: spiral addresses (integers with digits in 0..6)
    `mod`: if non-zero, the result is reduced modulo `mod`

    Returns the sum as an int, or None (after printing a message) if one
    of the addresses is invalid.
    """
    dig_a = _spiral_digits(a)
    dig_b = _spiral_digits(b)

    if dig_a is None or dig_b is None:
        print("Invalid spiral address!")
        return None

    if len(dig_a) == 1 and len(dig_b) == 1:
        res = _SPIRAL_ADDITION_TABLE[dig_a[0]][dig_b[0]]
    else:
        # Pad the shorter address with leading zeros.
        length = max(len(dig_a), len(dig_b))
        dig_a += [0] * (length - len(dig_a))
        dig_b += [0] * (length - len(dig_b))

        res = 0
        for i in range(length):
            partial = spiral_add(dig_a[i], dig_b[i])
            if i == length - 1:
                res += partial * 10**i
            else:
                res += (partial % 10) * 10**i
                # The carry has to stay an integer ('//', not '/').
                # It is spiral-added to the next digit of `a` and may
                # itself span several digits; the recursive call on the
                # next iteration deals with that.
                dig_a[i + 1] = spiral_add(dig_a[i + 1], partial // 10)

    if mod != 0:
        return res % mod

    return res


# In[11]:


def spiral_mult(a, b, mod=0):
    """
    Spiral multiplication of two spiral addresses.

    `a`, `b`: spiral addresses (integers with digits in 0..6)
    `mod`: if non-zero, the result is reduced modulo `mod`

    Returns the product as an int, or None (after printing a message) if
    one of the addresses is invalid.
    """
    dig_a = _spiral_digits(a)
    dig_b = _spiral_digits(b)

    if dig_a is None or dig_b is None:
        print("Invalid spiral address!")
        return None

    sa_mult = 0
    # Long multiplication: digit-by-digit products, shifted to their
    # decimal position and accumulated with spiral addition.
    for i, digit_b in enumerate(dig_b):
        for j, digit_a in enumerate(dig_a):
            temp = _SPIRAL_MULTIPLICATION_TABLE[digit_a][digit_b] * 10**(i + j)
            sa_mult = spiral_add(sa_mult, temp)

    if mod != 0:
        return sa_mult % mod

    return sa_mult


# ## Computation of the bispectrum

# We start by computing the vector $\omega_f(\lambda)$, where $\lambda$ is a certain spiral address.

# In[12]:


def omegaf(fft_oversampled, sa):
    """
    Evaluates the vector omegaf corresponding to the given spiral address sa.

    `fft_oversampled`: the oversampled FFT of the image
    `sa`: the spiral address where to compute the vector

    Returns a length-6 array whose k-th entry is the hyperpel value at
    spiral_mult(sa, k+1).
    """
    omega = np.zeros(6, dtype=fft_oversampled.dtype)

    for k in range(1, 7):
        omega[k - 1] = sa_value(fft_oversampled, spiral_mult(sa, k))

    return omega


# Then, we can compute the "generalized invariant" corresponding to $\lambda_1$, $\lambda_2$ and $\lambda_3$, starting from the FFT of the image.
# That is
#
# $$
# I^3_f(\lambda_1,\lambda_2,\lambda_3) = \langle\omega_f(\lambda_1)\odot\omega_f(\lambda_2),\omega_f(\lambda_3)\rangle.
# $$

# In[13]:


def invariant(fft_oversampled, sa1, sa2, sa3):
    """
    Evaluates the generalized invariant of f on sa1, sa2 and sa3

    `fft_oversampled`: the oversampled FFT of the image
    `sa1`, `sa2`, `sa3`: the spiral addresses where to compute the invariant
    """
    omega1 = omegaf(fft_oversampled, sa1)
    omega2 = omegaf(fft_oversampled, sa2)
    omega3 = omegaf(fft_oversampled, sa3)

    # Attention: np.vdot uses the scalar product with the complex
    # conjugation at the first place!
    return np.vdot(omega1 * omega2, omega3)


# Finally, this function computes the bispectrum (or the rotational bispectrum) corresponding to the spiral addresses in the following picture.

# Hexagonal pixels

# In[14]:


def bispectral_inv(fft_oversampled_example, rotational = False):
    """
    Computes the (rotational) bispectral invariants for any sa1 and any sa2
    in the above picture.

    `fft_oversampled_example`: oversampled FFT of the image
    `rotational`: if True, we compute the rotational bispectrum

    Returns a 1-D array with 9**2 entries (9**2 * 6 when `rotational`).
    """
    indexes = [0, 1, 10, 11, 12, 13, 14, 15, 16]

    if rotational:
        bispectrum = np.zeros(9**2 * 6, dtype=fft_oversampled_example.dtype)
    else:
        bispectrum = np.zeros(9**2, dtype=fft_oversampled_example.dtype)

    # The same omega vectors are needed over and over: compute each once.
    omega_cache = {}

    def omega(sa):
        if sa not in omega_cache:
            omega_cache[sa] = omegaf(fft_oversampled_example, sa)
        return omega_cache[sa]

    count = 0
    for sa1 in indexes:
        sa1_base10 = base7to10(sa1)

        for sa2 in indexes:
            if rotational:
                # NOTE(review): r runs over 0..5 here, while omegaf uses the
                # factors 1..6; r = 0 sends sa2 to 0 (zero row of the
                # multiplication table). Kept as is; confirm it is intended.
                second_addresses = [spiral_mult(sa2, r) for r in range(6)]
            else:
                second_addresses = [sa2]

            for sa2_shifted in second_addresses:
                # sa3 comes from ordinary integer addition of the base-10
                # values, converted back to base-7 digits.
                sa3 = base10to7(sa1_base10 + base7to10(sa2_shifted))
                # Same value as invariant(fft_oversampled_example, sa1, sa2, sa3);
                # np.vdot conjugates its first argument.
                bispectrum[count] = np.vdot(omega(sa1) * omega(sa2), omega(sa3))
                count += 1

    return bispectrum


# Some timing tests.

# In[15]:


example = 1 - rgb2gray(plt.imread('./test-images/butterfly.png'))
fft_example = np.fft.fftshift(np.fft.fft2(example))
fft_oversampled_example = oversampling(fft_example)


# In[16]:


get_ipython().run_cell_magic('timeit', '', 'bispectral_inv(fft_oversampled_example)\n')


# In[17]:


get_ipython().run_cell_magic('timeit', '', 'bispectral_inv(fft_oversampled_example, rotational=True)\n')


# ## Tests

# Here we define various functions to batch test the images in the test folder.

# In[18]:


folder = './test-images'


# In[19]:


def evaluate_invariants(image, rot = False):
    """
    Evaluates the invariants of the given image.

    `image`: the matrix representing the image (not oversampled)
    `rot`: if True we compute the rotational bispectrum
    """
    # compute the normalized FFT
    fft_image = np.fft.fftshift(np.fft.fft2(image))
    # Divide by the norm itself: dividing by (fft / norm) would replace
    # every entry by the norm and discard the image content.
    fft_image /= LA.norm(fft_image)

    # oversample it
    fft_oversampled = oversampling(fft_image)

    return bispectral_inv(fft_oversampled, rotational = rot)


# Some timing tests.
# In[20]:


get_ipython().run_cell_magic('timeit', '', 'evaluate_invariants(example)\n')


# In[21]:


get_ipython().run_cell_magic('timeit', '', 'evaluate_invariants(example, rot = True)\n')


# In[22]:


def bispectral_folder(folder_name = folder, rot = False):
    """
    Evaluates all the invariants of the images in the selected folder,
    storing them in a dictionary with their names as keys.

    `folder_name`: path to the folder
    `rot`: if True we compute the rotational bispectrum

    Only regular files with the extension '.png' are processed; the keys
    are the file names without extension.
    """
    # we store the results in a dictionary
    results = {}

    # Sorted, so that the order of the entries (and of the points in the
    # plots) does not depend on the file system.
    for filename in sorted(os.listdir(folder_name)):
        infilename = os.path.join(folder_name, filename)
        if not os.path.isfile(infilename):
            continue

        name, extension = os.path.splitext(filename)
        if extension == '.png':
            test_img = 1 - rgb2gray(plt.imread(infilename))
            results[name] = evaluate_invariants(test_img, rot = rot)

    return results


# In[24]:


def bispectral_comparison(bispectrums, comparison = 'triangle', plot = True, log_scale = True):
    """
    Returns the difference of the norms of the given invariants w.r.t. the
    comparison element.

    `bispectrums`: a dictionary with as keys the names of the images and
                   as values their invariants
    `comparison`: the element to use as comparison
    `plot`, `log_scale`: not used by this function; kept so that existing
                   calls keep working

    Returns a dictionary name -> norm of the difference, without the
    entries whose difference is NaN. If `comparison` is not a key of
    `bispectrums`, a message is printed and None is returned.
    """
    if comparison not in bispectrums:
        print("The requested comparison is not in the folder")
        return None

    reference = bispectrums[comparison]
    bispectrum_diff = {}
    for name, values in bispectrums.items():
        diff = LA.norm(values - reference)

        # we remove nan results
        if not np.isnan(diff):
            bispectrum_diff[name] = diff

    return bispectrum_diff


# In[25]:


def bispectral_plot(bispectrums, comparison = 'triangle', log_scale = True):
    """
    Plots the difference of the norms of the given invariants w.r.t. the
    comparison element (by default in logarithmic scale).

    `bispectrums`: a dictionary with as keys the names of the images and
                   as values their invariants
    `comparison`: the element to use as comparison
    `log_scale`: whether the plot should be in log_scale

    Nothing is plotted if `comparison` is not among the images.
    """
    bispectrum_diff = bispectral_comparison(bispectrums, comparison = comparison)
    if bispectrum_diff is None:
        # bispectral_comparison has already reported the problem.
        return

    # Dictionary views cannot be indexed on Python 3: materialize them.
    names = list(bispectrum_diff.keys())
    values = list(bispectrum_diff.values())

    plt.plot(values, 'ro')
    if log_scale:
        plt.yscale('log')

    for i, (name, value) in enumerate(zip(names, values)):
        # if we plot in log scale, we do not put labels on items that are
        # too small, otherwise they exit the plot area.
        if log_scale and value < 10**(-3):
            continue
        plt.text(i, value, name[:3])

    plt.title("Comparison with as reference '"+ comparison +"'")


# ### Construction of the table for the paper

# In[26]:


comparisons_paper = ['triangle', 'rectangle', 'ellipse', 'etoile', 'diamond']


def extract_table_values(bispectrums, comparisons = comparisons_paper):
    """
    Extract the values for the table of the paper.

    `bispectrums`: a dictionary with as keys the names of the images and
                   as values their invariants
    `comparisons`: list of elements to use as comparison

    Returns a list of tuples. Each tuple contains the name of the
    comparison element, the maximal value of the difference of the norm
    of the invariants with its rotated and the minimal values of the
    same difference with the other images (both formatted as '%.2E').
    Images count as rotated versions of an element when their name
    starts with the name of the element.

    Elements that are not among the images are skipped. Raises ValueError
    if, for some element, no image or no other image is left to compare.
    """
    table_values = []

    for elem in comparisons:
        diff = bispectral_comparison(bispectrums, comparison = elem, plot = False)
        if diff is None:
            # Not in the folder: already reported, nothing to tabulate.
            continue

        match = [diff[name] for name in diff if name.startswith(elem)]
        not_match = [diff[name] for name in diff if not name.startswith(elem)]

        max_match = max(match)
        min_not_match = min(not_match)

        table_values.append((elem, '%.2E' % (max_match), '%.2E' % min_not_match))

    return table_values


# In[23]:


bispectrums = bispectral_folder()
bispectrums_rotational = bispectral_folder(rot=True)


# In[27]:


extract_table_values(bispectrums)


# In[28]:


extract_table_values(bispectrums_rotational)