Computation and comparision of the bispectrum and the rotational bispectrum

We show how to compute the bispectrum and the rotational bispectrum, as presented in the paper

  • Image processing in the semidiscrete group of rototranslations by D. Prandi, U. Boscain and J.-P. Gauthier.
In [1]:
import numpy as np
from numpy import fft
from numpy import linalg as LA
from scipy import ndimage
from scipy import signal
import matplotlib.pyplot as plt
import as cm
import os

%matplotlib inline

Auxiliary functions

In [2]:
def int2intvec(a):
    Auxiliary function to recover a vector with the digits of a 
    given integer (in inverse order)
    `a` : integer
    digit = a%10
    vec = np.array([digit],dtype=int)
    a = (a-digit)/10
    while a!=0:
        digit = a%10
        vec = np.append(vec,int(digit))
        a = (a-digit)/10
    return vec
In [3]:
ALPHABET7 = "0123456"
ALPHABET10 = "0123456789"

def base_encode(num, alphabet):
    Encode a number in Base X

    `num`: The number to encode
    if (str(num) == alphabet[0]):
        return int(0)
    arr = []
    base = len(alphabet)
    while num:
        rem = num % base
        num = num // base
    return int(''.join(arr))

def base7to10(num):
    Convert a number from base 10 to base 7
    `num`: The number to convert    
    arr = int2intvec(num)
    num = 0
    for i in range(len(arr)):
        num += arr[i]*(7**(i))
    return num
def base10to7(num):
    Convert a number from base 7 to base 10
    `num`: The number to convert    
    return base_encode(num, ALPHABET7)
In [4]:
def rgb2gray(rgb):
    Convert an image from RGB to grayscale
    `rgb`: The image to convert    
    r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2]
    gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
    return gray
In [5]:
def oversampling(image, factor = 7):
    Oversample a grayscale image by a certain factor, dividing each
    pixel in factor*factor subpixels with the same intensity.
    `image`:  The image to oversample
    `factor`: The oversampling factor
    old_shape = image.shape
    new_shape = (factor*old_shape[0], factor*old_shape[1])
    new_image = np.zeros(new_shape, dtype = image.dtype)
    for i in range(old_shape[0]):
        for j in range(old_shape[1]):
            new_image[factor*i:factor*i+factor,factor*j:factor*j+factor] = image[i,j]*np.ones((factor,factor))
    return new_image

Spiral architecture implementation

Spiral architecture has been introduced by Sheridan in

The implementation with hyperpels that we use in the following is presented in

For a more detailed implementation, see the notebook Hexagonal grid.

We start by defining the centered hyperpel, which is defined on a 9x9 grid and is composed of 56 pixels. It has the shape

# o o x x x x x o o
# o x x x x x x x o
# o x x x x x x x o
# x x x x x x x x x
# x x x C x x x x x 
# o x x x x x x x o
# o x x x x x x x o
# o o x x x x x o o
In [6]:
# The centered hyperpel
hyperpel = np.array([\
                [-2,3],[-1,3], [0,3], [1,3], [2,3], [3,3], [4,3],\
                [-2,2],[-1,2], [0,2], [1,2], [2,2], [3,2], [4,2],\
                [-3,1],[-2,1],[-1,1], [0,1], [1,1], [2,1], [3,1], [4,1],[5,1],\
                [-3,0],[-2,0],[-1,0], [0,0], [1,0], [2,0], [3,0], [4,0],[5,0],\
                [-2,-1],[-1,-1], [0,-1], [1,-1], [2,-1], [3,-1], [4,-1],\
                [-2,-2],[-1,-2], [0,-2], [1,-2], [2,-2], [3,-2], [4,-2],\
                [-1,-3], [0,-3], [1,-3], [2,-3], [3,-3]])

hyperpel_sa = hyperpel - np.array([1,1])

We now compute, in sa2hex, the address of the center of the hyperpel corresponding to a certain spiral address.

In [7]:
def sa2hex(spiral_address):
    # Split the number in basic unit and call the auxiliary function
    # Here we reverse the order, so that the index corresponds to the 
    # decimal position
    digits = str(spiral_address)[::-1] 
    hex_address = np.array([0,0])
    for i in range(len(digits)):
        if int(digits[i])<0 or int(digits[i])>6:
            print("Invalid spiral address!")
        elif digits[i]!= '0':
            hex_address += sa2hex_aux(int(digits[i]),i)
    return hex_address
# This computes the row/column positions of the base cases,
# that is, in the form a*10^(zeros).
def sa2hex_aux(a, zeros):
    # Base cases
    if zeros == 0:
        if a == 0:
            return np.array([0,0])
        elif a == 1:
            return np.array([0,8])
        elif a == 2:
            return np.array([-7,4])
        elif a == 3:
            return np.array([-7,-4])
        elif a == 4:
            return np.array([0,-8])
        elif a == 5:
            return np.array([7,-4])
        elif a == 6:
            return np.array([7,4])
    return sa2hex_aux(a,zeros-1)+ 2*sa2hex_aux(a%6 +1,zeros-1)

Then, we compute the value of the hyperpel corresponding to the spiral address, by averaging the values on the subpixels.

In [8]:
def sa_value(oversampled_image,spiral_address):
    Computes the value of the hyperpel corresponding to the given
    spiral coordinate.
    hp = hyperpel_sa + sa2hex(spiral_address)
    val = 0.
    for i in range(56):
        val += oversampled_image[hp[i,0],hp[i,1]]
    return val/56

Spiral addition and multiplication

In [10]:
def spiral_add(a,b,mod=0):
    addition_table = [
    dig_a = int2intvec(a)
    dig_b = int2intvec(b) 
    if (dig_a<0).any() or (dig_a>7).any() \
      or (dig_b<0).any() or (dig_b>7).any():
        print("Invalid spiral address!")
    if len(dig_a) == 1 and len(dig_b)==1:
        return addition_table[a][b]
    if len(dig_a) < len(dig_b):
    elif len(dig_b) < len(dig_a):
    res = 0
    for i in range(len(dig_a)):
        if i == len(dig_a)-1:
            res += spiral_add(dig_a[i],dig_b[i])*(10**i)
            temp = spiral_add(dig_a[i],dig_b[i])
            res += (temp%10)*(10**i)
            carry_on = spiral_add(dig_a[i+1],(temp - temp%10)/10)
            dig_a[i+1] = str(carry_on)
    if mod!=0:
        return res%mod
    return res
In [11]:
def spiral_mult(a,b, mod=0):
    multiplication_table = [
    dig_a = int2intvec(a)
    dig_b = int2intvec(b) 
    if (dig_a<0).any() or (dig_a>7).any() \
      or (dig_b<0).any() or (dig_b>7).any():
        print("Invalid spiral address!")
    sa_mult = int(0)
    for i in range(len(dig_b)):
        for j in range(len(dig_a)):
            temp = multiplication_table[dig_a[j]][dig_b[i]]*(10**(i+j))
    if mod!=0:
        return sa_mult%mod
    return sa_mult

Computation of the bispectrum

We start by computing the vector $\omega_f(\lambda)$, where $\lambda$ is a certain spiral address.

In [12]:
def omegaf(fft_oversampled, sa):
    Evaluates the vector omegaf corresponding to the given 
    spiral address sa.
    `fft_oversampled`: the oversampled FFT of the image
    `sa`: the spiral address where to compute the vector
    omegaf = np.zeros(6, dtype=fft_oversampled.dtype)
    for i in range(1,7):
        omegaf[i-1] = sa_value(fft_oversampled,spiral_mult(sa,i))
    return omegaf

Then, we can compute the "generalized invariant" corresponding to $\lambda_1$, $\lambda_2$ and $\lambda_3$, starting from the FFT of the image. That is

$$ I^3_f(\lambda_1,\lambda_2,\lambda_3) = \langle\omega_f(\lambda_1)\odot\omega_f(\lambda_2),\omega_f(\lambda_3)\rangle. $$

In [13]:
def invariant(fft_oversampled, sa1,sa2,sa3):
    Evaluates the generalized invariant of f on sa1, sa2 and sa3
    `fft_oversampled`: the oversampled FFT of the image
    `sa1`, `sa2`, `sa3`: the spiral addresses where to compute the invariant
    omega1 = omegaf(fft_oversampled,sa1)
    omega2 = omegaf(fft_oversampled,sa2)
    omega3 = omegaf(fft_oversampled,sa3)
    # Attention: np.vdot uses the scalar product with the complex 
    # conjugation at the first place!
    return np.vdot(omega1*omega2,omega3)

Finally, this function computes the bispectrum (or the rotational bispectrum) corresponding to the spiral addresses in the following picture. Hexagonal pixels

In [14]:
def bispectral_inv(fft_oversampled_example, rotational = False):
    Computes the (rotational) bispectral invariants for any sa1 
    and any sa2 in the above picture.
    `fft_oversampled_example`: oversampled FFT of the image
    `rotational`: if True, we compute the rotational bispectrum
    if rotational == True:
        bispectrum = np.zeros(9**2*6,dtype = fft_oversampled_example.dtype)
        bispectrum = np.zeros(9**2,dtype = fft_oversampled_example.dtype)
    indexes = [0,1,10,11,12,13,14,15,16]
    count = 0
    for i in range(9):
        sa1 = indexes[i]
        sa1_base10 = base7to10(sa1)
        for k in range(9):
            sa2 = indexes[k]
            if rotational == True:
                for r in range(6):
                    sa2_rot = spiral_mult(sa2,r)
                    sa2_rot_base10 = base7to10(sa2_rot)
                    sa3 = base10to7(sa1_base10+sa2_rot_base10)
                    count += 1
                sa2_base10 = base7to10(sa2)
                sa3 = base10to7(sa1_base10+sa2_base10)
                count += 1
    return bispectrum

Some timing tests.

In [15]:
example =  1 - rgb2gray(plt.imread('./test-images/butterfly.png'))
fft_example = np.fft.fftshift(np.fft.fft2(example))
fft_oversampled_example = oversampling(fft_example)
In [16]:
1 loops, best of 3: 372 ms per loop
In [17]:
bispectral_inv(fft_oversampled_example, rotational=True)
1 loops, best of 3: 2.35 s per loop


Here we define various functions to batch test the images in the test folder.

In [18]:
folder = './test-images'
In [19]:
def evaluate_invariants(image, rot = False):
    Evaluates the invariants of the given image.
    `image`: the matrix representing the image (not oversampled)
    `rot`: if True we compute the rotational bispectrum
    # compute the normalized FFT
    fft = np.fft.fftshift(np.fft.fft2(image))
    fft /= fft / LA.norm(fft)
    # oversample it
    fft_oversampled = oversampling(fft)
    return bispectral_inv(fft_oversampled, rotational = rot)

Some timing tests.

In [20]:
1 loops, best of 3: 1.07 s per loop
In [21]:
evaluate_invariants(example, rot = True)
1 loops, best of 3: 3.09 s per loop
In [22]:
def bispectral_folder(folder_name = folder, rot = False): 
    Evaluates all the invariants of the images in the selected folder, 
    storing them in a dictionary with their names as keys.
    `folder_name`: path to the folder
    `rot`: if True we compute the rotational bispectrum
    # we store the results in a dictionary
    results = {}
    for filename in os.listdir(folder_name):
        infilename = os.path.join(folder_name, filename)
        if not os.path.isfile(infilename): 

        base, extension = os.path.splitext(infilename)
        if extension == '.png':
            test_img = 1 - rgb2gray(plt.imread(infilename))
            bispectrum = evaluate_invariants(test_img, rot = rot)
            results[os.path.splitext(filename)[0]] = bispectrum
    return results
In [24]:
def bispectral_comparison(bispectrums, comparison = 'triangle', plot = True, log_scale = True):
    Returns the difference of the norms of the given invariants w.r.t. the 
    comparison element.
    `bispectrums`: a dictionary with as keys the names of the images and 
                    as values their invariants
    `comparison`:  the element to use as comparison
    if comparison not in bispectrums:
        print("The requested comparison is not in the folder")    
    bispectrum_diff = {}
    for elem in bispectrums:
        diff = LA.norm(bispectrums[elem]-bispectrums[comparison])
        # we remove nan results
        if not np.isnan(diff):
            bispectrum_diff[elem] = diff
    return bispectrum_diff
In [25]:
def bispectral_plot(bispectrums, comparison = 'triangle', log_scale = True):
    Plots the difference of the norms of the given invariants w.r.t. the 
    comparison element (by default in logarithmic scale).
    `bispectrums`: a dictionary with as keys the names of the images and 
                    as values their invariants
    `comparison`:  the element to use as comparison
    `log_scale`:   wheter the plot should be in log_scale
    bispectrum_diff = bispectral_comparison(bispectrums, comparison = comparison)

    if log_scale == True:
    for i in range(len(bispectrum_diff.values())):
        # if we plot in log scale, we do not put labels on items that are
        # too small, otherwise they exit the plot area.
        if log_scale and bispectrum_diff.values()[i] < 10**(-3):
        plt.title("Comparison with as reference '"+ comparison +"'")

Construction of the table for the paper

In [26]:
comparisons_paper = ['triangle', 'rectangle', 'ellipse', 'etoile', 'diamond']

def extract_table_values(bispectrums, comparisons = comparisons_paper):
    Extract the values for the table of the paper.
    `bispectrums`: a dictionary with as keys the names of the images and 
                    as values their invariants
    `comparison`:  list of elements to use as comparison
    Returns a list of tuples. Each tuple contains the name of the comparison 
    element, the maximal value of the difference of the norm of the invariants 
    with its rotated and the minimal values of the same difference with the 
    other images.
    table_values = []
    for elem in comparisons:
        diff = bispectral_comparison(bispectrums, comparison= elem, plot=False)

        l = len(elem)
        match = [x for x in diff.keys() if x[:l]==elem]
        not_match = [x for x in diff.keys() if x[:l]!=elem]

        max_match = max([ diff[k] for k in match ])
        min_not_match = min([ diff[k] for k in not_match ])
        table_values.append((elem,'%.2E' % (max_match),'%.2E' % min_not_match))
    return table_values
In [23]:
bispectrums = bispectral_folder()
bispectrums_rotational = bispectral_folder(rot=True)
/usr/local/lib/python2.7/site-packages/IPython/kernel/ RuntimeWarning: invalid value encountered in divide
In [27]:
[('triangle', '9.23E+10', '7.03E+12'),
 ('rectangle', '7.93E+10', '8.22E+12'),
 ('ellipse', '7.42E+10', '7.11E+12'),
 ('etoile', '7.27E+10', '5.54E+12'),
 ('diamond', '3.78E+10', '5.47E+12')]
In [28]:
[('triangle', '2.26E+11', '1.72E+13'),
 ('rectangle', '1.94E+11', '2.01E+13'),
 ('ellipse', '1.82E+11', '1.74E+13'),
 ('etoile', '1.78E+11', '1.36E+13'),
 ('diamond', '9.27E+10', '1.34E+13')]