#!/usr/bin/env python # coding: utf-8 # Convolutional Dictionary Learning # ================================= # # This example demonstrating the use of [dictlrn.DictLearn](http://sporco.rtfd.org/en/latest/modules/sporco.dictlrn.dictlrn.html#sporco.dictlrn.dictlrn.DictLearn) to construct a dictionary learning algorithm with the flexibility of choosing the sparse coding and dictionary update classes. In this case they are [cbpdn.ConvBPDNGradReg](http://sporco.rtfd.org/en/latest/modules/sporco.admm.cbpdn.html#sporco.admm.cbpdn.ConvBPDNGradReg) and [admm.ccmod.ConvCnstrMOD](http://sporco.rtfd.org/en/latest/modules/sporco.admm.ccmod.html#sporco.admm.ccmod.ConvCnstrMOD) respectively, so the resulting dictionary learning algorithm is not equivalent to [dictlrn.cbpdndl.ConvBPDNDictLearn](http://sporco.rtfd.org/en/latest/modules/sporco.dictlrn.cbpdndl.html#sporco.dictlrn.cbpdndl.ConvBPDNDictLearn). Sparse coding with a CBPDN variant that includes a gradient regularization term on one of the coefficient maps [[52]](http://sporco.rtfd.org/en/latest/zreferences.html#id55) enables CDL without the need for the usual lowpass/highpass filtering as a pre-processing of the training images. # In[1]: from __future__ import division from __future__ import print_function from builtins import input import pyfftw # See https://github.com/pyFFTW/pyFFTW/issues/40 import numpy as np from sporco.admm import cbpdn from sporco.admm import ccmod from sporco.dictlrn import dictlrn from sporco import cnvrep from sporco import util from sporco import plot plot.config_notebook_plotting() # Load training images. # In[2]: exim = util.ExampleImages(scaled=True, zoom=0.5, gray=True) img1 = exim.image('barbara.png', idxexp=np.s_[10:522, 100:612]) img2 = exim.image('kodim23.png', idxexp=np.s_[:, 60:572]) img3 = exim.image('monarch.png', idxexp=np.s_[:, 160:672]) S = np.stack((img1, img2, img3), axis=2) # Construct initial dictionary. 
# In[3]:

# Fixed seed so the randomly initialised dictionary is reproducible.
np.random.seed(12345)
D0 = np.random.randn(8, 8, 64)


# Construct object representing problem dimensions.

# In[4]:

cri = cnvrep.CDU_ConvRepIndexing(D0.shape, S)


# Set up weights for the $\ell_1$ norm to disable regularization of the
# coefficient map corresponding to the impulse filter.

# In[5]:

# Shape (1, 1, 1, 1, M): one weight per filter, broadcast over the spatial
# and image axes. Weight 0 on the first (impulse) filter disables its
# l1 penalty.
wl1 = np.ones((1,)*4 + D0.shape[2:], dtype=np.float32)
wl1[..., 0] = 0.0

# Set of weights for the $\ell_2$ norm of the gradient to disable
# regularization of all coefficient maps except for the one corresponding to
# the impulse filter.

# In[6]:

wgr = np.zeros(D0.shape[2], dtype=np.float32)
wgr[0] = 1.0


# Define X and D update options.

# In[7]:

lmbda = 0.1   # l1 regularization parameter
mu = 0.5      # gradient regularization parameter
optx = cbpdn.ConvBPDNGradReg.Options(
    {'Verbose': False, 'MaxMainIter': 1, 'rho': 20.0*lmbda + 0.5,
     'AutoRho': {'Period': 10, 'AutoScaling': False, 'RsdlRatio': 10.0,
                 'Scaling': 2.0, 'RsdlTarget': 1.0},
     'HighMemSolve': True, 'AuxVarObj': False,
     'L1Weight': wl1, 'GradWeight': wgr})
optd = ccmod.ConvCnstrMODOptions(
    {'Verbose': False, 'MaxMainIter': 1, 'rho': 5.0*cri.K,
     'AutoRho': {'Period': 10, 'AutoScaling': False, 'RsdlRatio': 10.0,
                 'Scaling': 2.0, 'RsdlTarget': 1.0}},
    method='cns')


# Normalise dictionary according to dictionary Y update options.

# In[8]:

D0n = cnvrep.Pcn(D0, D0.shape, cri.Nv, dimN=2, dimC=0, crp=True,
                 zm=optd['ZeroMean'])


# Update D update options to include initial values for Y and U.

# In[9]:

optd.update(
    {'Y0': cnvrep.zpad(cnvrep.stdformD(D0n, cri.Cd, cri.M), cri.Nv),
     'U0': np.zeros(cri.shpD + (cri.K,))})


# Create X update object.

# In[10]:

xstep = cbpdn.ConvBPDNGradReg(D0n, S, lmbda, mu, optx)


# Create D update object.

# In[11]:

dstep = ccmod.ConvCnstrMOD(None, S, D0.shape, optd, method='cns')


# Create DictLearn object and solve.

# In[12]:

opt = dictlrn.DictLearn.Options({'Verbose': True, 'MaxMainIter': 200})
d = dictlrn.DictLearn(xstep, dstep, opt)
D1 = d.solve()
print("DictLearn solve time: %.2fs" % d.timer.elapsed('solve'), "\n")


# Display dictionaries.
# In[13]:

# Drop the singleton channel axis so tiledict gets a (H, W, M) array.
D1 = D1.squeeze()
fig = plot.figure(figsize=(14, 7))
plot.subplot(1, 2, 1)
plot.imview(util.tiledict(D0), title='D0', fig=fig)
plot.subplot(1, 2, 2)
plot.imview(util.tiledict(D1), title='D1', fig=fig)
fig.show()


# Plot functional value and residuals.

# In[14]:

itsx = xstep.getitstat()
itsd = dstep.getitstat()
fig = plot.figure(figsize=(20, 5))
plot.subplot(1, 3, 1)
plot.plot(itsx.ObjFun, xlbl='Iterations', ylbl='Functional', fig=fig)
plot.subplot(1, 3, 2)
# Primal/dual residuals of both the sparse coding (X) and dictionary
# update (D) ADMM solvers, on a log scale.
plot.plot(np.vstack((itsx.PrimalRsdl, itsx.DualRsdl,
                     itsd.PrimalRsdl, itsd.DualRsdl)).T,
          ptyp='semilogy', xlbl='Iterations', ylbl='Residual',
          lgnd=['X Primal', 'X Dual', 'D Primal', 'D Dual'], fig=fig)
plot.subplot(1, 3, 3)
plot.plot(np.vstack((itsx.Rho, itsd.Rho)).T, xlbl='Iterations',
          ylbl='Penalty Parameter', ptyp='semilogy',
          lgnd=['Rho', 'Sigma'], fig=fig)
fig.show()