Energy auto-encoder: learn from an image


In [ ]:
%matplotlib inline
import numpy as np
import numpy.linalg as la
import matplotlib.pyplot as plt
from pyunlocbox import functions, solvers
import time
import matplotlib, pyunlocbox  # For versions only.
# Report the versions of the main numerical packages so results can be
# matched to the software that produced them.
print('Software versions:')
for module in (np, matplotlib, pyunlocbox):
    print('  {}: {}'.format(module.__name__, module.__version__))


  • The $\lambda$ parameters set the relative importance of each term in the composite objective function.
  • The sparse code dimensionality $m$ should be greater than $n$ for an overcomplete representation but much smaller than $N$ to avoid over-fitting.
  • $N_p$ denotes the size of the (square) patches.
In [ ]:
# Weights of the dictionary (l_d) and encoder (l_e) data terms of the objective.
# NOTE(review): l_g is not referenced by the optimization below; presumably it
# is reserved for an additional term -- TODO confirm or remove.
l_d, l_e, l_g = 10, 10, 0
# Sparse code dimensionality, i.e. number of dictionary atoms.
m = 200
# Side length, in pixels, of the square patches.
Np = 12


The set of data vectors $X \in R^{n \times N}$ is given by patches extracted from a grayscale image.

In [ ]:
image = matplotlib.image.imread('data/barbara.png')
(Nx, Ny) = np.shape(image)  # Requires a single-channel (grayscale) image.

plt.imshow(image, cmap='gray')
plt.title('Source image')

# Cut the image into non-overlapping Np x Np patches, scanned row by row, and
# flatten each one into a column of X. Pixels beyond the last full patch along
# each axis are dropped. Integer division keeps the shape and indices integral,
# which np.zeros and array indexing require under Python 3.
Px, Py = Nx // Np, Ny // Np  # Number of full patches along each image axis.
X = np.zeros((Np**2, Px*Py))
for x in range(Px):
    for y in range(Py):
        patch = image[x*Np:(x+1)*Np, y*Np:(y+1)*Np]
        # Row-major patch index: every patch gets its own column.
        X[:, x*Py + y] = patch.reshape((Np**2,))
(n, N) = np.shape(X)

print('N = %d samples with dimensionality n = %d (patches of %dx%d).' % (N, n, Np, Np))


Given $X \in R^{n \times N}$, solve $\min\limits_{Z \in R^{m \times N}, D \in R^{n \times m}, E \in R^{m \times n}} \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \frac{\lambda_e}{2} \|Z - EX\|_F^2 + \|Z\|_1$ s.t. $\|d_i\|_2 \leq 1$, $\|e_k\|_2 \leq 1$, $i = 1, \ldots, m$, $k = 1, \ldots, n$

In [ ]:
# Solver numeric parameters.
N_outer = 20
rtol = 1e-3
# If True, tie the encoder to the dictionary (E = D^T) after each D and E
# update. False solves the problem as stated above, with D and E independent.
tied = False

# Static loss function definitions.
g_z = functions.norm_l1()
g_de = functions.proj_b2(epsilon=1)  # L2-ball indicator function.

# Initialization (seeded so that Restart & Run All reproduces the results).
np.random.seed(0)
Z = np.random.normal(size=(m, N))
D = np.random.normal(size=(n, m))
E = np.random.normal(size=(m, n))
# Per-inner-iteration [smooth term, non-smooth term] values of each
# sub-problem, as reported by the solver, concatenated over outer iterations.
objective_z = []
objective_d = []
objective_e = []
# Global objective, one value per outer iteration.
objective_g = []
tstart = time.time()

# X never changes, so neither does the Lipschitz constant of the E
# sub-problem. la.norm of a matrix is the Frobenius norm, an upper bound on
# the spectral norm, so the resulting step size is conservative.
L_e = l_e * la.norm(np.dot(X, X.T))

# Multi-variate non-convex optimization (outer loop): alternately minimize
# over Z, D and E, each sub-problem being convex.
for k in range(N_outer):
    # Convex minimization for Z:
    #   l_d/2 ||X - DZ||_F^2 + l_e/2 ||Z - EX||_F^2 + ||Z||_1.
    f_zd = functions.norm_l2(lambda_=l_d/2., A=D, y=X, tight=False)
    f_ze = functions.norm_l2(lambda_=l_e/2., y=np.dot(E, X))
    f_z = functions.func()
    # Bind this iteration's terms explicitly instead of relying on late
    # lookup of the global names f_zd / f_ze.
    f_z._eval = lambda z, fd=f_zd, fe=f_ze: fd.eval(z) + fe.eval(z)
    f_z._grad = lambda z, fd=f_zd, fe=f_ze: fd.grad(z) + fe.grad(z)
    L = l_e + l_d * la.norm(np.dot(D.T, D))  # Lipschitz continuous gradient.
    solver = solvers.forward_backward(step=1./L, method='FISTA')
    ret = solvers.solve([f_z, g_z], Z, solver, rtol=rtol, verbosity='NONE')
    Z = ret['sol']
    objective_z.extend(ret['objective'])

    # Convex minimization for D, solved for D^T:
    #   l_d/2 ||Z^T D^T - X^T||_F^2 + indicator of the L2 ball.
    f_d = functions.norm_l2(lambda_=l_d/2., A=Z.T, y=X.T, tight=False)
    L = l_d * la.norm(np.dot(Z, Z.T))  # Lipschitz continuous gradient.
    solver = solvers.forward_backward(step=1./L, method='FISTA')
    ret = solvers.solve([f_d, g_de], D.T, solver, rtol=rtol, verbosity='NONE')
    D = ret['sol'].T
    objective_d.extend(ret['objective'])
    if tied:
        E = D.T

    # Convex minimization for E, solved for E^T:
    #   l_e/2 ||X^T E^T - Z^T||_F^2 + indicator of the L2 ball.
    f_e = functions.norm_l2(lambda_=l_e/2., A=X.T, y=Z.T, tight=False)
    solver = solvers.forward_backward(step=1./L_e, method='FISTA')
    ret = solvers.solve([f_e, g_de], E.T, solver, rtol=rtol, verbosity='NONE')
    E = ret['sol'].T
    objective_e.extend(ret['objective'])
    if tied:
        D = E.T

    # Global objective (the indicators are 0).
    objective_g.append(g_z.eval(Z) + f_d.eval(D.T) + f_e.eval(E.T))

print('Elapsed time: %d seconds' % (time.time() - tstart))

Convergence analysis

  • Although the overall multivariate problem is not convex, it seems to converge toward a solution. As each (convex) sub-problem is solved to optimality (up to the tolerance `rtol`), the global objective function should decrease monotonically, which is indeed the case.
  • Good news: the encoder seems to approximate the sparse code very well (low L2 reconstruction error)!
In [ ]:
print('g_z(Z) = %e' % g_z.eval(Z))
print('f_z(Z,D) = %e' % f_z.eval(Z))
print('f_d(D,Z) = %e' % f_d.eval(D.T))
print('f_e(E,Z) = %e' % f_e.eval(E.T))
print('g_z(Z) + f_d(D,Z) + f_e(E,Z) = %e' % objective_g[-1])

# Sub-problem convergence. Each row is [smooth term, non-smooth term] for one
# inner iteration, concatenated over all outer iterations.
obj_z = np.array(objective_z)
obj_d = np.array(objective_d)
obj_e = np.array(objective_e)
plt.figure()
plt.semilogy(obj_z[:, 0], label='Z: data term')
plt.semilogy(obj_z[:, 1], label='Z: prior term')
plt.semilogy(obj_d[:, 0], label='D: data term')
plt.semilogy(obj_e[:, 0], label='E: data term')
# Dedicated names: N holds the number of samples and is used by later cells.
n_inner = obj_z.shape[0]
plt.xlim(0, n_inner-1)
plt.title('Sub-problems convergence')
plt.xlabel('Iteration number (inner loops)')
plt.ylabel('Objective function value')
plt.grid(True); plt.legend();
print('Inner loop: %d iterations' % n_inner)

# Global convergence: one value per outer iteration, on its own figure.
n_outer = len(objective_g)
plt.figure()
plt.semilogy(objective_g)
plt.xlim(0, n_outer-1)
plt.title('Global convergence')
plt.xlabel('Iteration number (outer loop)')
plt.ylabel('Objective function value')
plt.grid(True)
print('Outer loop: %d iterations\n' % n_outer)

Solution analysis

Sparse codes

  • They can be made arbitrarily sparse by decreasing $\lambda_d$ and $\lambda_e$.
In [ ]:
# Exact zeros are counted: the l1 proximal step produces them.
nnz = np.count_nonzero(Z)
print('Sparsity of Z: %d non-zero entries out of %d entries, i.e. %.1f%%.' % (nnz, Z.size, 100.*nnz/Z.size))

# Sparsity pattern of the codes: one row per atom, one column per sample.
# Dimensions are read from Z itself, so they are correct even if the global
# N has been reused elsewhere in the notebook.
plt.spy(Z, precision=0, aspect='auto')
plt.xlabel('N = %d samples' % Z.shape[1])
plt.ylabel('m = %d atoms' % Z.shape[0])


  • All constraints are indeed honored.
  • Only a few atoms are actually used, as the sparsity pattern of the sparse codes already shows.
  • Learned atoms seem plausible but highly repetitive (keep in mind that the images are normalized).
In [ ]:
# Column norms of D, i.e. the norms of the m atoms d_i.
d = np.sqrt(np.sum(D*D, axis=0))
# np.all replaces np.alltrue, which was removed in NumPy 2.0.
print('Constraints on D: %s' % np.all(d <= 1))

plt.figure()
plt.semilogy(d, '.')
plt.title('Dictionary atom norms')
plt.xlabel('Atom [1,m]')
plt.ylabel('Norm [0,1]')

# Separate figure: spy must not be drawn on the log-scale norm plot.
plt.figure()
plt.spy(D, precision=1e-7)
plt.xlabel('m = %d atoms' % m)
plt.ylabel('data dimensionality of n = %d' % n)

# TODO: use plt.scatter to also show the coefficient intensities.
In [ ]:
# Lay the m atoms out on a near-square grid. Dedicated names keep the image
# dimensions Nx, Ny intact, and the counts are converted to int because
# subplot requires integer grid sizes.
n_cols = int(np.ceil(np.sqrt(m)))
n_rows = int(np.ceil(m / float(n_cols)))
for k in range(m):
    plt.subplot(n_rows, n_cols, k + 1)  # Subplot indices are 1-based.
    img = D[:, k].reshape(Np, Np)
    plt.imshow(img, cmap='gray')  # vmin=0, vmax=1 to disable normalization.
    plt.axis('off')


  • All constraints are indeed honored.
  • Depending on the conditions, the encoder resembles the transpose of the dictionary, i.e. $E \approx D^T$.
In [ ]:
# Column norms of E, i.e. the norms of the n encoder columns e_k.
e = np.sqrt(np.sum(E*E, axis=0))
# np.all replaces np.alltrue, which was removed in NumPy 2.0.
print('Constraints on E: %s' % np.all(e <= 1))

# Plot the norms the title refers to (previously missing).
plt.figure()
plt.semilogy(e, '.')
plt.title('Encoder norms')
plt.xlabel('Column [1,n]')
plt.ylabel('Norm [0,1]')

# Separate figure: spy must not be drawn on the log-scale norm plot.
plt.figure()
plt.spy(E, precision=1e-7)
plt.xlabel('data dimensionality of n = %d' % n)
plt.ylabel('m = %d atoms' % m)