#!/usr/bin/env python
# coding: utf-8

# # Energy auto-encoder: learn from an image

# ## Setup

# In[ ]:

get_ipython().run_line_magic('matplotlib', 'inline')

import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt
from pyunlocbox import functions, solvers
import time

import matplotlib, pyunlocbox  # For versions only.

print('Software versions:')
for pkg in [np, matplotlib, pyunlocbox]:
    print('  %s: %s' % (pkg.__name__, pkg.__version__))


# ## Hyper-parameters
# 
# * The $\lambda$ coefficients set the relative importance of each term in the composite objective function.
# * The sparse code dimensionality $m$ should be greater than $n$ for an overcomplete representation, but much smaller than $N$ to avoid over-fitting.
# * $N_p$ denotes the side length of the (square) patches.

# In[ ]:

l_d = 10
l_e = 10
l_g = 0
m = 200
Np = 12


# ## Data
# 
# The set of data vectors $X \in R^{n \times N}$ is given by non-overlapping patches extracted from a grayscale image.

# In[ ]:

image = matplotlib.image.imread('data/barbara.png')
(Nx, Ny) = np.shape(image)
plt.figure(figsize=(8,5))
plt.imshow(image, cmap='gray')
plt.title('Source image')
plt.show()

# Extract non-overlapping Np x Np patches, one per column of X.
X = np.zeros((Np**2, (Nx//Np) * (Ny//Np)))
for x in range(Nx//Np):
    for y in range(Ny//Np):
        X[:, x*(Ny//Np) + y] = image[x*Np:(x+1)*Np, y*Np:(y+1)*Np].reshape(Np**2)
(n, N) = np.shape(X)
print('N = %d samples with dimensionality n = %d (patches of %dx%d).' % (N, n, Np, Np))


# ## Algorithm
# 
# Given $X \in R^{n \times N}$, solve
# 
# $\min\limits_{Z \in R^{m \times N}, D \in R^{n \times m}, E \in R^{m \times n}} \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \frac{\lambda_e}{2} \|Z - EX\|_F^2 + \|Z\|_1$
# 
# subject to $\|d_i\|_2 \leq 1$ and $\|e_k\|_2 \leq 1$ for $i = 1, \ldots, m$ and $k = 1, \ldots, n$, where $d_i$ and $e_k$ denote the columns of the dictionary $D$ and of the encoder $E$, respectively.
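# 
# The problem is not jointly convex, but it is convex in each block of variables taken separately. The code below therefore alternates convex minimizations over $Z$, $D$ and $E$, each solved with the forward-backward (FISTA) solver of pyunlocbox. As a reference for the step sizes used in the code, the smooth part of the $Z$ sub-problem is
# 
# $f(Z) = \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \frac{\lambda_e}{2} \|Z - EX\|_F^2, \qquad \nabla f(Z) = \lambda_d D^T (DZ - X) + \lambda_e (Z - EX),$
# 
# whose gradient is Lipschitz continuous with constant $\lambda_d \|D^T D\|_2 + \lambda_e$. The code estimates $\|D^T D\|_2$ with `la.norm(np.dot(D.T, D))`, i.e. the Frobenius norm, which upper bounds the spectral norm, so the step $1/L$ stays valid (just conservative). The same reasoning gives $L = \lambda_d \|Z Z^T\|$ for the $D$ step and $L = \lambda_e \|X X^T\|$ for the $E$ step.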
# In[ ]:

# Solver numeric parameters.
N_outer = 20
rtol = 1e-3

# Static loss function definitions.
g_z = functions.norm_l1()
g_de = functions.proj_b2(epsilon=1)  # L2-ball indicator function.

# Initialization.
Z = np.random.normal(size=(m, N))
D = np.random.normal(size=(n, m))
E = np.random.normal(size=(m, n))
objective_z = []
objective_d = []
objective_e = []
objective_g = []
tstart = time.time()

# Multi-variate non-convex optimization (outer loop).
for k in np.arange(N_outer):

    # Convex minimization for Z.
    f_zd = functions.norm_l2(lambda_=l_d/2., A=D, y=X, tight=False)
    f_ze = functions.norm_l2(lambda_=l_e/2., y=np.dot(E, X))
    f_z = functions.func()
    f_z._eval = lambda Z: f_zd.eval(Z) + f_ze.eval(Z)
    f_z._grad = lambda Z: f_zd.grad(Z) + f_ze.grad(Z)
    L = l_e + l_d * la.norm(np.dot(D.T, D))  # Lipschitz constant of the gradient.
    solver = solvers.forward_backward(step=1./L, method='FISTA')
    ret = solvers.solve([f_z, g_z], Z, solver, rtol=rtol, verbosity='NONE')
    Z = ret['sol']
    objective_z.extend(ret['objective'])
    objective_d.extend(np.zeros(np.shape(ret['objective'])))
    objective_e.extend(np.zeros(np.shape(ret['objective'])))

    # Convex minimization for D.
    f_d = functions.norm_l2(lambda_=l_d/2., A=Z.T, y=X.T, tight=False)
    L = l_d * la.norm(np.dot(Z, Z.T))  # Lipschitz constant of the gradient.
    solver = solvers.forward_backward(step=1./L, method='FISTA')
    ret = solvers.solve([f_d, g_de], D.T, solver, rtol=rtol, verbosity='NONE')
    D = ret['sol'].T
    objective_d.extend(ret['objective'])
    objective_z.extend(np.zeros(np.shape(ret['objective'])))
    objective_e.extend(np.zeros(np.shape(ret['objective'])))
    E = D.T

    # Convex minimization for E.
    f_e = functions.norm_l2(lambda_=l_e/2., A=X.T, y=Z.T, tight=False)
    L = l_e * la.norm(np.dot(X, X.T))  # Lipschitz constant of the gradient.
    solver = solvers.forward_backward(step=1./L, method='FISTA')
    ret = solvers.solve([f_e, g_de], E.T, solver, rtol=rtol, verbosity='NONE')
    E = ret['sol'].T
    objective_e.extend(ret['objective'])
    objective_z.extend(np.zeros(np.shape(ret['objective'])))
    objective_d.extend(np.zeros(np.shape(ret['objective'])))
    D = E.T

    # Global objective (the indicator functions evaluate to 0).
    objective_g.append(g_z.eval(Z) + f_d.eval(D.T) + f_e.eval(E.T))

print('Elapsed time: %d seconds' % (time.time() - tstart))


# ## Convergence analysis
# 
# * Although the overall multi-variate problem is not convex, the iterates seem to converge toward a solution. Since each convex sub-problem is solved optimally, the global objective function is guaranteed to decrease monotonically, which is indeed the case.
# * Good news: the encoder approximates the sparse code very well (low L2 reconstruction error); see the quick check after the convergence plots.

# In[ ]:

print('g_z(Z) = %e' % g_z.eval(Z))
print('f_z(Z,D) = %e' % f_z.eval(Z))
print('f_d(D,Z) = %e' % f_d.eval(D.T))
print('f_e(E,Z) = %e' % f_e.eval(E.T))
print('g_z(Z) + f_d(D,Z) + f_e(E,Z) = %e' % objective_g[-1])

plt.figure(figsize=(8,5))
plt.semilogy(np.array(objective_z)[:, 0], label='Z: data term')
plt.semilogy(np.array(objective_z)[:, 1], label='Z: prior term')
#plt.semilogy(np.sum(objective[:,0:2], axis=1), label='Z: sum')
plt.semilogy(np.array(objective_d)[:, 0], label='D: data term')
plt.semilogy(np.array(objective_e)[:, 0], label='E: data term')
n_inner = np.shape(objective_z)[0]
plt.xlim(0, n_inner-1)
plt.title('Sub-problems convergence')
plt.xlabel('Iteration number (inner loops)')
plt.ylabel('Objective function value')
plt.grid(True); plt.legend(); plt.show()
print('Inner loops: %d iterations in total' % n_inner)

plt.figure(figsize=(8,5))
plt.plot(objective_g)
n_outer = np.shape(objective_g)[0]
plt.xlim(0, n_outer-1)
plt.title('Global convergence')
plt.xlabel('Iteration number (outer loop)')
plt.ylabel('Objective function value')
plt.grid(True); plt.show()
print('Outer loop: %d iterations\n' % n_outer)
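# To put a number on the encoder claim, a minimal check is to print the relative residuals of the two data-fidelity terms using the variables computed above: the first measures how well the dictionary reconstructs the data, the second how well the encoder predicts the sparse code.

# In[ ]:

# Relative residuals of the decoder and encoder fits (Frobenius norms).
err_d = la.norm(X - np.dot(D, Z)) / la.norm(X)
err_e = la.norm(Z - np.dot(E, X)) / la.norm(Z)
print('Decoder: ||X - DZ||_F / ||X||_F = %.2e' % err_d)
print('Encoder: ||Z - EX||_F / ||Z||_F = %.2e' % err_e)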
# ## Solution analysis
# 
# ### Sparse codes
# 
# * They can be made arbitrarily sparse by decreasing $\lambda_d$ and $\lambda_e$.

# In[ ]:

nnz = np.count_nonzero(Z)
#nnz = np.sum(np.abs(Z) < 1e-4)
print('Sparsity of Z: %d non-zero entries out of %d entries, i.e. %.1f%%.' % (nnz, Z.size, 100.*nnz/Z.size))
plt.figure(figsize=(8,5))
plt.spy(Z, precision=0, aspect='auto')
plt.xlabel('N = %d samples' % N)
plt.ylabel('m = %d atoms' % m)
plt.show()


# ### Dictionary
# 
# * All constraints are indeed honored.
# * Only a few atoms are actually used, which can already be seen from the sparsity pattern of the sparse codes.
# * The learned atoms seem plausible but highly repetitive (keep in mind that the displayed patches are normalized).

# In[ ]:

d = np.sqrt(np.sum(D*D, axis=0))
print('Constraints on D: %s' % np.all(d <= 1))
plt.figure(figsize=(8,5))
plt.semilogy(d, '.')
plt.title('Dictionary atom norms')
plt.xlabel('Atom [1,m]')
plt.ylabel('Norm [0,1]')
plt.grid(True); plt.show()

plt.figure(figsize=(8,5))
plt.spy(D, precision=1e-7)
plt.xlabel('m = %d atoms' % m)
plt.ylabel('data dimensionality n = %d' % n)
plt.show()
#plt.scatter to show intensity


# In[ ]:

plt.figure(figsize=(8,8))
ncols = int(np.ceil(np.sqrt(m)))
nrows = int(np.ceil(m / float(ncols)))
for k in range(m):
    plt.subplot(nrows, ncols, k+1)
    img = D[:, k].reshape(Np, Np)
    plt.imshow(img, cmap='gray')  # vmin=0, vmax=1 to disable normalization.
    plt.axis('off')


# ### Encoder
# 
# * All constraints are indeed honored.
# * Depending on the conditions, the encoder resembles the transpose of the dictionary, i.e. $E \approx D^T$.

# In[ ]:

e = np.sqrt(np.sum(E*E, axis=0))
print('Constraints on E: %s' % np.all(e <= 1))
plt.figure(figsize=(8,5))
plt.semilogy(e)
plt.title('Encoder column norms')
plt.xlabel('Column [1,n]')
plt.ylabel('Norm [0,1]')
plt.grid(True); plt.show()

plt.figure(figsize=(8,5))
plt.spy(E, precision=1e-7)
plt.xlabel('data dimensionality n = %d' % n)
plt.ylabel('m = %d atoms' % m)
plt.show()
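# ### Reconstruction
# 
# A final visual sanity check is to decode the sparse codes with the learned dictionary and tile the patches back into an image. The sketch below assumes the non-overlapping patch layout used in the Data section; replacing `np.dot(D, Z)` by `np.dot(D, np.dot(E, X))` would test the full encode-then-decode path instead.

# In[ ]:

# Decode the sparse codes and reassemble the patches into an image.
X_rec = np.dot(D, Z)  # X is approximated by D Z (first data-fidelity term).
image_rec = np.zeros(((Nx//Np) * Np, (Ny//Np) * Np))
for x in range(Nx//Np):
    for y in range(Ny//Np):
        patch = X_rec[:, x*(Ny//Np) + y].reshape(Np, Np)
        image_rec[x*Np:(x+1)*Np, y*Np:(y+1)*Np] = patch
plt.figure(figsize=(8,5))
plt.imshow(image_rec, cmap='gray')
plt.title('Reconstruction from the sparse codes (D Z)')
plt.show()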