%%file matrixmodel.py
# -*- coding: utf-8 -*-
"""
This file defines a class for matrix forward operators that are useful for
iterative reconstruction algorithms.
Created on Tue Jul 26 2016
@author: U. S. Kamilov, 2016.
"""
import numpy as np
class MatrixOperator:
    """
    Forward operator defined by an explicit matrix.

    The operator maps a signal of shape ``sigSize`` to a column vector of
    measurements by flattening the signal and multiplying it with
    ``H_matrix``. The "adjoint" multiplies with the plain (non-conjugated)
    transpose of ``H_matrix`` and reshapes the result back to ``sigSize``,
    so it is the true adjoint only for real-valued matrices.
    """

    def __init__(self, H_matrix, sigSize):
        """
        Class constructor

        H_matrix: 2-D measurement matrix; its number of columns must equal
                  the number of elements of a signal of shape sigSize.
        sigSize:  shape of the signal, as accepted by np.reshape.
        """
        self.H_matrix = H_matrix
        self.sigSize = sigSize
        # The squared spectral norm (largest singular value squared) of H
        # equals the largest eigenvalue of H^T H. It is computed once here
        # because it requires a singular value decomposition.
        self.L = np.linalg.norm(H_matrix, 2) ** 2

    def apply(self, f):
        """
        Apply the forward operator

        f: signal as an array or any array-like (e.g. a nested list); it is
           flattened in row-major order before the multiplication.
        Returns the measurements as a column vector of shape (M, 1), where
        M is the number of rows of H_matrix.
        """
        # (-1, 1) lets numpy infer the length, so array-likes without a
        # .size attribute (lists, tuples) are accepted as well.
        fvec = np.reshape(f, (-1, 1))
        return np.dot(self.H_matrix, fvec)

    def applyAdjoint(self, z):
        """
        Apply the adjoint of the forward operator

        z: measurement vector with as many entries as H_matrix has rows.
        Returns the back-projected signal reshaped to sigSize.
        """
        fvec = np.dot(np.transpose(self.H_matrix), z)
        return np.reshape(fvec, self.sigSize)

    def getLipschitzConstant(self):
        """
        Return the largest eigenvalue of H^T H, i.e. the Lipschitz constant
        of the gradient of 0.5*||y - H f||^2.
        """
        return self.L
Writing matrixmodel.py
%%file algorithm.py
# -*- coding: utf-8 -*-
"""
This file implements the fast iterative shrinkage/thresholding algorithm (FISTA).
For more details on this code, see the corresponding paper:
U. S. Kamilov, "Minimizing Isotropic Total Variation without Subiterations,"
Proc. 3rd International Traveling Workshop on Interactions between Sparse models
and Technology (iTWIST 2016) (Aalborg, Denmark, August 24-26)
Created on Tue Jul 16 07:08:59 2016
@author: U. S. Kamilov, 2016.
"""
import numpy as np
def fistaEst(y, forwardObj, tau, numIter = 100, accelerate = True):
    """
    Iterative reconstruction with FISTA

    Each iteration takes a gradient step on the quadratic data term
    0.5*||y - H s||^2, applies soft-thresholding with threshold
    stepSize*tau, and then extrapolates the next evaluation point.

    y:          measurements; must be compatible with forwardObj.apply(...)
                so that the residual apply(s) - y is well defined.
    forwardObj: operator exposing sigSize, apply(), applyAdjoint() and
                getLipschitzConstant().
    tau:        weight of the l1 penalty.
    numIter:    number of iterations to run.
    accelerate: if False the relaxation parameter stays at 1, the
                extrapolation term vanishes and the method reduces to ISTA.

    Returns (fhat, cost): the final estimate of shape sigSize and an array
    of shape (numIter, 1) where cost[k] is the energy of the estimate
    produced by iteration k, so cost[-1] belongs to the returned fhat.
    """
    # Lipschitz constant is used for obtaining the step-size
    stepSize = 1/forwardObj.getLipschitzConstant()

    # Define iterates; the reconstruction starts from the all-zero signal.
    fhat = np.zeros(forwardObj.sigSize)
    # Aliasing fhat is safe: neither array is modified in place below.
    s = fhat
    q = 1

    # Tracks evolution of the energy functional
    cost = np.zeros((numIter, 1))

    # Iterate (the index is named k to avoid shadowing the builtin iter)
    for k in range(numIter):
        # Gradient step
        residual = forwardObj.apply(s) - y
        fhatnext = s - stepSize*forwardObj.applyAdjoint(residual)

        # Proximal step
        fhatnext = softThresh(fhatnext, stepSize*tau)

        # Update the relaxation parameter
        if accelerate:
            qnext = 0.5*(1 + np.sqrt(1 + 4*(q**2)))
        else:
            qnext = 1

        # Relaxation step
        s = fhatnext + ((q - 1)/qnext)*(fhatnext - fhat)

        # Update variables
        q = qnext
        fhat = fhatnext

        # Evaluate the cost of the iterate just computed. Evaluating it
        # before the update would record the previous iterate instead and
        # never report the cost of the returned estimate.
        cost[k] = evaluateCost(y, forwardObj, fhat, tau)

    return fhat, cost
def softThresh(w, lam):
    """
    Soft-thresholding function

    Shrinks the magnitude of every entry of w by lam and sets entries whose
    magnitude does not exceed lam to zero, keeping the sign (or phase) of
    each entry.

    w:   scalar, array or array-like of values to shrink.
    lam: threshold; a scalar or anything that broadcasts against w.

    Returns the shrunk values with the same shape as w. The input is not
    modified.
    """
    # Accept scalars and lists too; numpy scalars do not support the masked
    # item assignment that an in-place zero guard would need.
    w = np.asarray(w)

    # element-wise magnitude
    abs_w = np.abs(w)

    # magnitude that remains after shrinkage
    shrunk = np.maximum(abs_w - lam, 0)

    # Divide by 1 instead of 0 where w is zero; the output is zero there
    # anyway because the factor is multiplied by w.
    denom = np.where(abs_w <= 0, 1, abs_w)

    # scale each entry by the fraction of its magnitude that survives
    return (shrunk/denom) * w
def evaluateCost(y, forwardObj, f, tau):
    """
    Evaluate the energy functional

    Computes 0.5*||y - H f||^2 + tau*||f||_1 and divides it by
    0.5*||y||^2, so the value is relative to the energy of the
    measurements (it equals 1 for f = 0).

    y:          measurements, as a column vector or a flat array.
    forwardObj: operator whose apply(f) returns the predicted measurements.
    f:          signal estimate.
    tau:        weight of the l1 penalty.

    Returns the relative cost as a scalar. If y is identically zero the
    normalization divides by zero and the result is inf or nan.
    """
    # Flatten both sides before subtracting: a 1-D y minus an (M, 1)
    # prediction would otherwise broadcast to an (M, M) matrix and
    # silently produce a wrong cost.
    yvec = np.ravel(y)
    residual = yvec - np.ravel(forwardObj.apply(f))

    dataTerm = 0.5*np.sum(residual**2)
    penalty = tau*np.sum(np.abs(f))

    return (dataTerm + penalty)/(0.5*np.sum(yvec**2))
Writing algorithm.py
%reset -f
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from scipy import misc
from matrixmodel import MatrixOperator
from algorithm import fistaEst
# Demo: recover a sparse random image from noisy random linear measurements
# with ISTA and FISTA, then compare their cost curves and reconstructions.
# NOTE(review): no random seed is set in this cell, so the signal, the
# measurement matrix and the noise differ from run to run.
# Size of the signal
sigSize = (32, 32)
# Generate a sparse random signal: a 0/1 mask that is 1 with probability 0.1,
# multiplied element-wise by standard-normal amplitudes
f = np.random.binomial(1, 0.1, sigSize) * np.random.randn(sigSize[0], sigSize[1])
# Total number of pixels
numPixels = np.size(f)
# Plot the ground-truth signal
plt.figure()
imgplot = plt.imshow(f, interpolation='nearest')
imgplot.set_cmap('gray')
plt.colorbar()
plt.axis("off")
plt.title("Sparse Signal")
plt.show()
# Define the measurement model and its transpose
# Number of measurements: two thirds of the number of pixels, rounded up
numMeasurements = np.ceil(2*numPixels/3).astype(int)
# Measurement matrix with standard-normal entries
H_matrix = np.random.randn(numMeasurements, numPixels)
# Define a class object for easier manipulation
forwardObj = MatrixOperator(H_matrix, sigSize)
# Noise standard deviation
stdDev = 0.01
# Generate noise as a column vector, matching the output of forwardObj.apply
noise = stdDev*np.random.randn(numMeasurements, 1)
# Generate measurements
y = forwardObj.apply(f) + noise
# Benchmark reconstruction
# Regularization parameter
tau = 1
# Total number of iterations
numIter = 100
# Reconstruct with ISTA
[fhatISTA, costISTA] = fistaEst(y, forwardObj, tau, numIter, accelerate=False)
# Reconstruct with FISTA
[fhatFISTA, costFISTA] = fistaEst(y, forwardObj, tau, numIter, accelerate=True)
# Iteration axis 1..numIter for the cost curves
t = np.linspace(1, numIter, numIter)
# Plot both cost curves on a logarithmic vertical axis
plt.figure()
plt.semilogy(t, costISTA, 'r-', t, costFISTA, 'b-')
plt.xlim(1, numIter)
plt.grid(True)
plt.legend(["ISTA", "FISTA"])
plt.title("Evolution of the Cost")
plt.xlabel("t")
plt.ylabel("Relative Cost")
# Show the two reconstructions side by side
plt.figure()
# Plot the ISTA reconstruction
plt.subplot(1, 2, 1)
imgplot = plt.imshow(fhatISTA, interpolation='nearest')
imgplot.set_cmap('gray')
plt.axis("off")
plt.title("ISTA")
# Plot the FISTA reconstruction
plt.subplot(1, 2, 2)
imgplot = plt.imshow(fhatFISTA, interpolation='nearest')
imgplot.set_cmap('gray')
plt.axis("off")
plt.title("FISTA")
plt.show()