from __future__ import print_function from builtins import input from builtins import range import pyfftw # See https://github.com/pyFFTW/pyFFTW/issues/40 import numpy as np from sporco import util from sporco import plot plot.config_notebook_plotting() import sporco.linalg as spl import sporco.metric as sm from sporco.admm import cbpdn def pad(x, n=8): if x.ndim == 2: return np.pad(x, n, mode='symmetric') else: return np.pad(x, ((n, n), (n, n), (0, 0)), mode='symmetric') def crop(x, n=8): return x[n:-n, n:-n] img = util.ExampleImages().image('monarch.png', zoom=0.5, scaled=True, gray=True, idxexp=np.s_[:, 160:672]) np.random.seed(12345) imgn = img + np.random.normal(0.0, 0.1, img.shape) npd = 16 fltlmbd = 5.0 imgnl, imgnh = util.tikhonov_filter(imgn, fltlmbd, npd) D = util.convdicts()['G:8x8x128'] imgnpl, imgnph = util.tikhonov_filter(pad(imgn), fltlmbd, npd) W = spl.irfftn(np.conj(spl.rfftn(D, imgnph.shape, (0, 1))) * spl.rfftn(imgnph[..., np.newaxis], None, (0, 1)), imgnph.shape, (0,1)) W = W**2 W = 1.0/(np.maximum(np.abs(W), 1e-8)) lmbda = 4.8e-2 opt = cbpdn.ConvBPDN.Options({'Verbose': True, 'MaxMainIter': 250, 'HighMemSolve': True, 'RelStopTol': 3e-3, 'AuxVarObj': False, 'L1Weight': W, 'AutoRho': {'Enabled': False}, 'rho': 4e2*lmbda}) b = cbpdn.ConvBPDN(D, pad(imgnh), lmbda, opt, dimK=0) X = b.solve() imgdp = b.reconstruct().squeeze() imgd = np.clip(crop(imgdp) + imgnl, 0, 1) print("ConvBPDN solve time: %5.2f s" % b.timer.elapsed('solve')) print("Noisy image PSNR: %5.2f dB" % sm.psnr(img, imgn)) print("Denoised image PSNR: %5.2f dB" % sm.psnr(img, imgd)) fig = plot.figure(figsize=(21, 7)) plot.subplot(1, 3, 1) plot.imview(img, title='Reference', fig=fig) plot.subplot(1, 3, 2) plot.imview(imgn, title='Noisy', fig=fig) plot.subplot(1, 3, 3) plot.imview(imgd, title='CSC Result', fig=fig) fig.show() its = b.getitstat() plot.plot(its.ObjFun, xlbl='Iterations', ylbl='Functional') plot.plot(np.vstack((its.PrimalRsdl, its.DualRsdl)).T, ptyp='semilogy', xlbl='Iterations', ylbl='Residual', lgnd=['Primal', 'Dual']) plot.plot(its.Rho, xlbl='Iterations', ylbl='Penalty Parameter')