Sparse energy auto-encoders

  • The definition of the algorithm behind our sparse energy auto-encoder model.
  • An unsupervised feature extraction tool which tries to find a good sparse representation efficiently.
  • This notebook is meant to be imported by other notebooks for applications to image or audio data.
  • Modeled after the sklearn Estimator class so that it can be integrated into an sklearn Pipeline. Note that matrix dimensions are transposed between code and math to follow sklearn conventions (samples as rows).

Algorithm

General problem:

  • given $X \in R^{n \times N}$ and a graph Laplacian $L \in R^{N \times N}$,
  • solve $\min\limits_{Z \in R^{m \times N}, D \in R^{n \times m}, E \in R^{m \times n}} \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \frac{\lambda_e}{2} \|Z - EX\|_F^2 + \lambda_s \|Z\|_1 + \frac{\lambda_g}{2} \text{tr}(Z L Z^T)$
  • s.t. $\|d_i\|_2 \leq 1$, $\|e_k\|_2 \leq 1$, $i = 1, \ldots, m$, $k = 1, \ldots, n$

which can be reduced to sparse coding with dictionary learning:

  • given $X \in R^{n \times N}$,
  • solve $\min\limits_{Z \in R^{m \times N}, D \in R^{n \times m}} \frac{\lambda_d}{2} \|X - DZ\|_F^2 + \lambda_s \|Z\|_1$
  • s.t. $\|d_i\|_2 \leq 1$, $i = 1, \ldots, m$
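
To make the objective concrete, the cell below is a minimal sketch (illustrative shapes and helper name, not part of the model or the class below) that evaluates each term with NumPy, using the code convention of this notebook (samples as rows): X is (N, n), Z is (N, m), D is (m, n), E is (n, m) and L is the (N, N) graph Laplacian. Setting le and lg to zero recovers the reduced sparse coding objective above.

In [ ]:
import numpy as np
import scipy.sparse

def objective(X, Z, D, E, L, ls, ld, le, lg):
    """Illustrative evaluation of the four terms of the general objective,
    in the code (samples as rows) convention. Not used by the class below."""
    fidelity_d = ld / 2. * np.linalg.norm(X - Z.dot(D), 'fro')**2  # ||X - DZ||_F^2 in math orientation.
    fidelity_e = le / 2. * np.linalg.norm(Z - X.dot(E), 'fro')**2  # ||Z - EX||_F^2 in math orientation.
    sparsity = ls * np.abs(Z).sum()                                # ||Z||_1.
    smoothness = lg / 2. * np.trace(Z.T.dot(L.dot(Z)))             # Graph smoothness term.
    return fidelity_d + fidelity_e + sparsity + smoothness

if False:  # Toy example with random data, mirroring the unit tests at the bottom.
    N, n, m = 25, 16, 20
    X = np.random.normal(size=(N, n))
    Z = np.random.normal(size=(N, m))
    D = np.random.normal(size=(m, n))
    E = np.random.normal(size=(n, m))
    L = scipy.sparse.identity(N, format='csr')
    print(objective(X, Z, D, E, L, ls=1, ld=10, le=100, lg=1))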

Observations:

  • Almost ten times faster (measured on comparison_xavier) with optimized linear algebra subroutines (the snippet below shows how to check which BLAS NumPy is linked against):
    • None: 9916s
    • ATLAS: 1335s (seems to be memory bandwidth limited)
    • OpenBLAS: 1371s (seems more CPU intensive than ATLAS)
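
To reproduce such a comparison on another machine, the linked BLAS can be identified with the standard NumPy introspection call (the same check, guarded by if False, appears in the unit tests at the bottom of this notebook):

In [ ]:
# Show which BLAS/LAPACK libraries NumPy was built against.
import numpy as np
np.__config__.show()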

Open questions:

  • Optimize for Z first (previous implementation) or for D/E first (new implementation)?
    • Seems to converge much faster when Z is optimized last (see comparison_xavier).
    • But it is two times slower.
    • In fit we optimize for the parameters D and E, so it makes sense to optimize them last.
    • Z must be optimized first if it is initialized with zeros.
  • Fast evaluation of la.norm(Z.T.dot(Z)). Compute it cumulatively to save memory? See the sketch after this list.
  • Consider adding an option for $E = D^T$.
  • Use single precision, i.e. float32? Yes, it saves memory and speeds up computation thanks to the reduced memory bandwidth.
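
Regarding the la.norm(Z.T.dot(Z)) question, one possible (untested here) alternative is sketched below: estimate the spectral norm $\|Z\|_2^2$ by power iteration instead of taking the Frobenius norm of the Gram matrix. It never forms Z.T.dot(Z) and yields a tighter Lipschitz bound, hence a larger FISTA step; the price is an iterative estimate which can undershoot the true constant if not converged, so a safety margin is advisable.

In [ ]:
import numpy as np
import numpy.linalg as la

def lipschitz_power(Z, n_iter=50):
    """Estimate ||Z^T Z||_2 = ||Z||_2^2 by power iteration, without forming the Gram matrix."""
    v = np.random.normal(size=Z.shape[1]).astype(Z.dtype)
    v /= la.norm(v)
    for _ in range(n_iter):
        v = Z.T.dot(Z.dot(v))  # Two matrix-vector products per iteration.
        v /= la.norm(v)
    return la.norm(Z.dot(v))**2  # Rayleigh quotient v^T Z^T Z v with ||v||_2 = 1.

if False:  # The Frobenius norm used in _minD upper-bounds this estimate.
    Z = np.random.normal(size=(1000, 100))
    print(lipschitz_power(Z), la.norm(Z.T.dot(Z)))
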
In [ ]:
import numpy as np
import numpy.linalg as la
from pyunlocbox import functions, solvers
import matplotlib.pyplot as plt
%matplotlib inline

class auto_encoder():
    """Sparse energy auto-encoder."""
    
    def __init__(self, m=100, ls=None, ld=None, le=None, lg=None,
                 rtol=1e-3, xtol=None, N_inner=100, N_outer=15):
        """
        Model hyper-parameters and solver stopping criteria.
        
        Model hyper-parameters:
            m:  number of atoms in the dictionary, i.e. the sparse code length
            ls: weight of the l1 sparsity penalty on the code
            ld: weight of the dictionary l2 penalty
            le: weight of the encoder l2 penalty
            lg: weight of the graph smoothness penalty
        
        Stopping criteria:
            rtol: objective function convergence
            xtol: model parameters convergence
            N_inner: hard limit of inner iterations
            N_outer: hard limit of outer iterations
        """
        self.m = m
        self.ls = ls
        self.ld = ld
        self.le = le
        self.lg = lg
        self.N_outer = N_outer
        
        # Solver common parameters.
        self.params = {'rtol':       rtol,
                       'xtol':       xtol,
                       'maxit':      N_inner,
                       'verbosity': 'NONE'}

    def _convex_functions(self, X, L, Z):
        """Define convex functions."""
        
        f = functions.proj_b2()
        self.f = functions.func()
        self.f._eval = lambda X: 0
        self.f._prox = lambda X,_: f._prox(X.T, 1).T
        #self.f._prox = lambda X,_: _normalize(X)
        
        if self.ld is not None:
            self.g_d = functions.norm_l2(lambda_=self.ld/2., A=Z, y=X, tight=False)
            self.g_z = functions.norm_l2(lambda_=self.ld/2., A=self.D.T, y=X.T, tight=False)
        else:
            self.g_z = functions.dummy()

        if self.le is not None:
            self.h_e = functions.norm_l2(lambda_=self.le/2., A=X, y=Z, tight=False)
            self.h_z = functions.norm_l2(lambda_=self.le/2., y=lambda: X.dot(self.E).T, tight=True)
        else:
            self.h_z = functions.dummy()

        if self.lg is not None:
            self.j_z = functions.func()
            # tr(A*B) = sum(A.*B^T).
            #self.j_z._eval = lambda Z: self.lg/2. * np.trace(Z.dot(L.dot(Z.T)))
            #self.j_z._eval = lambda Z: self.lg/2. * np.multiply(L.dot(Z.T), Z.T).sum()
            self.j_z._eval = lambda Z: self.lg/2. * np.einsum('ij,ji->', L.dot(Z.T), Z)
            self.j_z._grad = lambda Z: self.lg * L.dot(Z.T).T
        else:
            self.j_z = functions.dummy()

        self.ghj_z = functions.func()
        self.ghj_z._eval = lambda Z: self.j_z._eval(Z) + self.g_z._eval(Z) + self.h_z._eval(Z)
        self.ghj_z._grad = lambda Z: self.j_z._grad(Z) + self.g_z._grad(Z) + self.h_z._grad(Z)
        
        if self.ls is not None:
            self.i_z = functions.norm_l1(lambda_=self.ls)
        else:
            self.i_z = functions.dummy()

    def _minD(self, X, Z):
        """Convex minimization for D."""
        
        # Lipschitz continuous gradient. Faster if larger dim is 'inside'.
        B = self.ld * la.norm(Z.T.dot(Z))
        
        solver = solvers.forward_backward(step=1./B, method='FISTA')
        ret = solvers.solve([self.g_d, self.f], self.D, solver, **self.params)
        
        self.objective_d.extend(ret['objective'])
        self.objective_z.extend([[0,0]] * len(ret['objective']))
        self.objective_e.extend([[0,0]] * len(ret['objective']))
    
    def _minE(self, X, Z):
        """Convex minimization for E."""
        
        # Lipschitz continuous gradient. Faster if larger dim is 'inside'.
        B = self.le * la.norm(X.T.dot(X))
        
        solver = solvers.forward_backward(step=1./B, method='FISTA')
        ret = solvers.solve([self.h_e, self.f], self.E, solver, **self.params)
        
        self.objective_e.extend(ret['objective'])
        self.objective_z.extend([[0,0]] * len(ret['objective']))
        self.objective_d.extend([[0,0]] * len(ret['objective']))
    
    def _minZ(self, X, L, Z):
        """Convex minimization for Z."""
        
        B_e = self.le if self.le is not None else 0
        B_d = self.ld * la.norm(self.D.T.dot(self.D)) if self.ld is not None else 0
        B_g = self.lg * np.sqrt((L.data**2).sum()) if self.lg is not None else 0
        B = B_d + B_e + B_g
        
        solver = solvers.forward_backward(step=1./B, method='FISTA')
        ret = solvers.solve([self.ghj_z, self.i_z], Z.T, solver, **self.params)
        
        self.objective_z.extend(ret['objective'])
        self.objective_d.extend([[0,0]] * len(ret['objective']))
        self.objective_e.extend([[0,0]] * len(ret['objective']))
        
    def fit_transform(self, X, L):
        """
        Fit the model parameters (dictionary and encoder) given the
        training data and the graph Laplacian.
        
        Parameters
        ----------
        X : ndarray, shape (N, n)
            Training vectors, where N is the number of samples
            and n is the number of features.
        L : scipy.sparse, shape (N, N)
            The Laplacian matrix of the graph.
            
        Returns
        -------
        Z : ndarray, shape (N, m)
            Sparse codes (a by-product of training), where N
            is the number of samples and m is the number of atoms.
        """
        N, n = X.shape
        
        def _normalize(X, axis=1):
            """Normalize the selected axis of an ndarray to unit norm."""
            return X / np.sqrt(np.sum(X**2, axis))[:,np.newaxis]
        
        # Model parameters initialization.
        if self.ld is not None:
            self.D = _normalize(np.random.uniform(size=(self.m, n)).astype(X.dtype))
        if self.le is not None:
            self.E = _normalize(np.random.uniform(size=(n, self.m)).astype(X.dtype))
        
        # Initial predictions.
        #Z = np.random.uniform(size=(N, self.m)).astype(X.dtype)
        Z = np.zeros(shape=(N, self.m), dtype=X.dtype)
        
        # Initialize convex functions.
        self._convex_functions(X, L, Z)
        
        # Objective functions.
        self.objective = []
        self.objective_g = []
        self.objective_h = []
        self.objective_i = []
        self.objective_j = []
        self.objective_z = []
        self.objective_d = []
        self.objective_e = []
        
        # Stopping criteria.
        crit = None
        niter = 0
        last = np.nan
        
        # Multi-variate non-convex optimization (outer loop).
        while not crit:
            niter += 1

            self._minZ(X, L, Z)

            if self.ld is not None:
                self._minD(X, Z)

            if self.le is not None:
                self._minE(X, Z)

            # Global objectives.
            self.objective_g.append(self.g_z.eval(Z.T))
            self.objective_h.append(self.h_z.eval(Z.T))
            self.objective_i.append(self.i_z.eval(Z.T))
            self.objective_j.append(self.j_z.eval(Z.T))
            
            if self.params['rtol'] is not None:
                current = 0
                for func in ['g', 'h', 'i', 'j']:
                    current += getattr(self, 'objective_'+func)[-1]
                relative = np.abs((current - last) / current)
                last = current
                if relative < self.params['rtol']:
                    crit = 'RTOL'

            if self.N_outer is not None and niter >= self.N_outer:
                crit = 'MAXIT'

        return Z
    
    def fit(self, X, L):
        """Fit to data without returning the transformed data."""
        self.fit_transform(X, L)
    
    def transform(self, X, L):
        """Predict sparse codes for each sample in X."""
        return self._transform_exact(X, L)
        
    def _transform_exact(self, X, L):
        """Most accurate but slowest prediction."""
        N = X.shape[0]
        Z = np.random.uniform(size=(N, self.m)).astype(X.dtype)
        self._convex_functions(X, L, Z)
        self._minZ(X, L, Z)
        return Z
    
    def _transform_approx(self, X, L):
        """Much faster approximation using only the encoder."""
        raise NotImplementedError('Not yet implemented')
    
    def inverse_transform(self, Z):
        """
        Return the data corresponding to the given sparse codes using
        the learned dictionary.
        """
        raise NotImplementedError('Not yet implemented')
    
    def plot_objective(self):
        """Plot the objective (cost, loss, energy) functions."""
        plt.figure(figsize=(8,5))
        plt.semilogy(np.asarray(self.objective_z)[:, 0], label='Z: data term')
        plt.semilogy(np.asarray(self.objective_z)[:, 1], label='Z: prior term')
        #plt.semilogy(np.sum(objective[:,0:2], axis=1), label='Z: sum')
        if self.ld is not None:
            plt.semilogy(np.asarray(self.objective_d)[:, 0], label='D: data term')
        if self.le is not None:
            plt.semilogy(np.asarray(self.objective_e)[:, 0], label='E: data term')
        iterations_inner = np.shape(self.objective_z)[0]
        plt.xlim(0, iterations_inner-1)
        plt.title('Sub-problems convergence')
        plt.xlabel('Iteration number (inner loops)')
        plt.ylabel('Objective function value')
        plt.grid(True); plt.legend(); plt.show()
        print('Inner loop: {} iterations'.format(iterations_inner))

        plt.figure(figsize=(8,5))
        def rdiff(a, b):
            print('rdiff: {}'.format(abs(a - b) / a))
        if self.ld is not None:
            name = 'g(Z) = ||X-DZ||_2^2'
            plt.semilogy(self.objective_g, '.-', label=name)
            print(name + ' = {:e}'.format(self.objective_g[-1]))
            rdiff(self.objective_g[-1], self.g_d.eval(self.D))
        if self.le is not None:
            name = 'h(Z) = ||Z-EX||_2^2'
            plt.semilogy(self.objective_h, '.-', label=name)
            print(name + ' = {:e}'.format(self.objective_h[-1]))
            rdiff(self.objective_h[-1], self.h_e.eval(self.E))
        name = 'i(Z) = ||Z||_1'
        plt.semilogy(self.objective_i, '.-', label=name)
        print(name + ' = {:e}'.format(self.objective_i[-1]))
        if self.lg is not None:
            name = 'j(Z) = tr(Z^TLZ)'
            plt.semilogy(self.objective_j, '.-', label=name)
            print(name + ' = {:e}'.format(self.objective_j[-1]))
        iterations_outer = len(self.objective_i)
        plt.xlim(0, iterations_outer-1)
        plt.title('Objectives convergence')
        plt.xlabel('Iteration number (outer loop)')
        plt.ylabel('Objective function value')
        plt.grid(True); plt.legend(loc='best'); plt.show()
        
        plt.figure(figsize=(8,5))
        objective = np.zeros((iterations_outer))
        for obj in ['g', 'h', 'i', 'j']:
            objective += np.asarray(getattr(self, 'objective_' + obj))
        print('Global objective: {:e}'.format(objective[-1]))
        plt.plot(objective, '.-')
        plt.xlim(0, iterations_outer-1)
        plt.title('Global convergence')
        plt.xlabel('Iteration number (outer loop)')
        plt.ylabel('Objective function value')
        plt.grid(True); plt.show()
        print('Outer loop: {} iterations\n'.format(iterations_outer))
        
        return (iterations_inner, iterations_outer,
                self.objective_g[-1], self.objective_h[-1],
                self.objective_i[-1], self.objective_j[-1])

Tools for solution analysis

Tools to show the model parameters, the sparse codes and the objective functions. The auto_encoder class itself contains only the core algorithm (and a visualization of its convergence).

In [ ]:
def sparse_codes(Z, tol=0):
    """Show the sparsity of the sparse codes."""
    N, m = Z.shape
    
    print('Z in [{}, {}]'.format(np.min(Z), np.max(Z)))
    
    if tol == 0:
        nnz = np.count_nonzero(Z)
    else:
        nnz = np.sum(np.abs(Z) > tol)
    sparsity = 100.*nnz/Z.size
    print('Sparsity of Z: {:,} non-zero entries out of {:,} entries, '
          'i.e. {:.1f}%.'.format(nnz, Z.size, sparsity))

    try:
        plt.figure(figsize=(8,5))
        plt.spy(Z.T, precision=tol, aspect='auto')
        plt.xlabel('N = {} samples'.format(N))
        plt.ylabel('m = {} atoms'.format(m))
        plt.show()
    except MemoryError:
        pass
    
    return sparsity
    
def dictenc(D, tol=1e-5, enc=False):
    """Show the norms and sparsity of the learned dictionary or encoder."""
    m, n = D.shape
    name = 'D' if not enc else 'E'
    
    print('{} in [{}, {}]'.format(name, np.min(D), np.max(D)))
    
    d = np.sqrt(np.sum(D**2, axis=1))
    print('{} in [{}, {}]'.format(name.lower(), np.min(d), np.max(d)))
    print('Constraints on {}: {}'.format(name, np.alltrue(d <= 1+tol)))
    
    plt.figure(figsize=(8,5))
    plt.plot(d, 'b.')
    #plt.ylim(0.5, 1.5)
    plt.xlim(0, m-1)
    if not enc:
        plt.title('Dictionary atom norms')
        plt.xlabel('Atom [1,m={}]'.format(m))
    else:
        plt.title('Encoder column norms')
        plt.xlabel('Column [1,n={}]'.format(m))
    plt.ylabel('Norm [0,1]')
    plt.grid(True); plt.show()

    plt.figure(figsize=(8,5))
    plt.spy(D.T, precision=1e-2, aspect='auto')
    if not enc:
        plt.xlabel('m = {} atoms'.format(m))
        plt.ylabel('data dimensionality of n = {}'.format(n))
    else:
        plt.xlabel('n = {} columns'.format(m))
        plt.ylabel('data dimensionality of m = {}'.format(n))
        
    plt.show()
    
    #plt.scatter to show intensity
    
def atoms(D, Np=None):
    """
    Show dictionary or encoder atoms.
    
    2D atoms if Np is not None, else 1D atoms.
    """
    m, n = D.shape
    
    fig = plt.figure(figsize=(8,8))
    Nx = int(np.ceil(np.sqrt(m)))
    Ny = int(np.ceil(m / float(Nx)))
    for k in np.arange(m):
        ax = fig.add_subplot(Ny, Nx, k+1)  # Subplot indices start at 1.
        if Np is not None:
            img = D[k,:].reshape(Np, Np)
            ax.imshow(img, cmap='gray')  # vmin=0, vmax=1 to disable normalization.
            ax.axis('off')
        else:
            ax.plot(D[k,:])
            ax.set_xlim(0, n-1)
            ax.set_ylim(-1, 1)
            ax.set_xticks([])
            ax.set_yticks([])
    return fig

Unit tests

Test the auto-encoder class and tools.

In [ ]:
if False:
    # ldd numpy/core/_dotblas.so
    try:
        import numpy.core._dotblas
        print('fast BLAS')
    except ImportError:
        print('slow BLAS')

    print(np.__version__)
    np.__config__.show()

if False:
#if __name__ is '__main__':
    import time
    import scipy.sparse
    
    # Data.
    N, n = 25, 16
    X = np.random.normal(size=(N, n))
    
    # Graph.
    W = np.random.uniform(size=(N, N))  # W in [0,1].
    W = np.maximum(W, W.T)  # Symmetric weight matrix, i.e. undirected graph.
    D = np.diag(W.sum(axis=0))  # Diagonal degree matrix.
    L = D - W  # Symmetric and positive Laplacian.
    L = scipy.sparse.csr_matrix(L)

    # Algorithm.
    auto_encoder(m=20, ls=1, le=1, rtol=1e-5, xtol=None).fit(X, L)
    auto_encoder(m=20, ld=1, rtol=1e-5, xtol=None).fit(X, L)
    auto_encoder(m=20, lg=1, rtol=None, xtol=None).fit(X, L)
    auto_encoder(m=20, lg=1, ld=1, rtol=1e-5, xtol=1e-5).fit(X, L)
    ae = auto_encoder(m=20, ls=5, ld=10, le=100, lg=1, rtol=1e-5, N_outer=20)
    tstart = time.time()
    Z = ae.fit_transform(X, L)
    print('Elapsed time: {:.3f} seconds'.format(time.time() - tstart))
    ret = ae.plot_objective()
    iterations_inner, iterations_outer = ret[:2]
    objective_g, objective_h, objective_i, objective_j = ret[2:]
    
    # Reproducible results (min_Z given X, L, D, E is convex).
    err = la.norm(Z - ae.transform(X, L)) / np.sqrt(Z.size) #< 1e-3
    print('Error: {}'.format(err))

    # Results visualization.
    sparse_codes(Z)
    dictenc(ae.D)
    dictenc(ae.E, enc=True)
    atoms(ae.D, 4)  # 2D atoms.
    atoms(ae.D)  # 1D atoms.
    atoms(ae.E)