MOSEK ApS

Regularized Wasserstein Barycenters using Mosek and the exponential cone

In a previous notebook related to Wasserstein distances we defined the linear optimization problem of computing the Wasserstein barycenter of a set of discrete measures. Here we solve an entropy-regularized variant of the same problem to demonstrate the exponential cone capabilities of MOSEK. We also use this problem to compare Fusion and CVXPY.

As a reminder, for discrete probability distributions $\mu,\upsilon$ supported on points $X_1,\dots,X_n$ and $Y_1,\dots,Y_m$, the $p$-th power $W_p(\mu,\upsilon)^p$ of the $p$-th order Wasserstein distance is the optimal value of the following problem:

$$ \mbox{minimize} \quad \sum_{i=1}^{n}\sum_{j=1}^{m} D(X_i,Y_j)^p\,\pi_{ij}$$
$$ \mbox{s.t.} \quad \sum_{j=1}^{m} \pi_{ij} = \mu_i, \quad i = 1,\dots,n $$
$$ \quad \sum_{i=1}^{n} \pi_{ij} = \upsilon_j, \quad j = 1,\dots,m $$
$$ \pi_{ij} \geq 0, \quad \forall\, i,j$$
where $D(X_i,Y_j)$ is the distance function.
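As a concrete illustration, the sketch below solves this linear program for two small distributions on the real line, with $D(x,y)=|x-y|$ and $p=2$. It uses SciPy's `linprog` (HiGHS backend) instead of MOSEK only so that it runs without a MOSEK license; the function name `wasserstein_pp` and the toy data are hypothetical choices made for this example.

```python
import numpy as np
from scipy.optimize import linprog


def wasserstein_pp(mu, nu, X, Y, p=2):
    """Return (W_p(mu, nu)**p, optimal plan) for points X, Y on the real line."""
    mu = np.asarray(mu, dtype=float)
    nu = np.asarray(nu, dtype=float)
    n, m = len(mu), len(nu)
    # Cost c_ij = D(X_i, Y_j)^p with D the absolute difference
    C = np.abs(np.subtract.outer(np.asarray(X, float), np.asarray(Y, float))) ** p
    # pi is flattened row-major, so pi_ij sits at position i*m + j
    A_rows = np.kron(np.eye(n), np.ones((1, m)))  # sum_j pi_ij = mu_i
    A_cols = np.kron(np.ones((1, n)), np.eye(m))  # sum_i pi_ij = nu_j
    res = linprog(
        C.ravel(),
        A_eq=np.vstack([A_rows, A_cols]),
        b_eq=np.concatenate([mu, nu]),
        bounds=(0, None),  # pi_ij >= 0
        method="highs",
    )
    return res.fun, res.x.reshape(n, m)


# Half of the mass sits at 0 and 1 in mu, and at 1 and 2 in nu
cost, plan = wasserstein_pp([0.5, 0.5, 0.0], [0.0, 0.5, 0.5], [0, 1, 2], [0, 1, 2])
print(cost)  # W_2^2: each half unit moves one step, so the value is 1
```

Moving mass monotonically (0 to 1 and 1 to 2) costs $0.5\cdot1+0.5\cdot1=1$, which beats sending the mass at 0 all the way to 2 (cost $0.5\cdot4=2$).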

Wasserstein Barycenter with regularization

The entropy-regularized barycenter problem with $p=2$ for $N$ images with distributions $\upsilon^1,\dots,\upsilon^N$ is:
$$ \mbox{minimize} \quad \frac1N \sum_{k=1}^{N}\left(\sum_{i,j} D(X_i,Y_j)^2\,\pi_{ij}^k + \frac1\lambda\sum_{i,j} \pi_{ij}^k\log\pi_{ij}^k\right)$$
$$\mbox{s.t.} \quad \sum_{j} \pi_{ij}^{k} = \mu_i, \quad \forall\, k,i \quad (1)$$
$$ \quad \sum_{i} \pi_{ij}^{k} = \upsilon_j^{k}, \quad \forall\, k,j \quad (2) $$
$$ \pi_{ij}^{k} \geq 0, \quad \forall\, k,i,j$$
where the barycenter $\mu$ is also a variable, $D(X_i,Y_j)$ is the Euclidean distance between pixels $i$ and $j$, $\lambda = 60/\operatorname{median}(D(X_i,Y_j)^2)$ as in the implementation below, and $N$ is the number of images.
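To make the objective concrete, the following sketch evaluates it with NumPy for given transport plans $\pi^1,\dots,\pi^N$. It follows the implementation below, which divides both terms by $N$, uses squared Euclidean distances between pixel coordinates and sets $\lambda = 60/\operatorname{median}(D^2)$. The helper names are hypothetical, and $0\log 0$ is taken as $0$, its limit.

```python
import numpy as np


def pixel_sq_distances(side):
    """Squared Euclidean distances between all pixels of a side x side image.

    Pixels are numbered row by row, matching a C-order ravel of the image.
    """
    rows, cols = np.divmod(np.arange(side * side), side)
    coords = np.stack([rows, cols], axis=1)
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sum(diff ** 2, axis=-1).astype(float)


def regularized_objective(plans, D2, lam):
    """(1/N) * sum_k [ <D2, pi^k> + (1/lam) * sum_ij pi^k_ij * log(pi^k_ij) ]."""
    total = 0.0
    for P in plans:
        transport = np.sum(D2 * P)
        positive = P[P > 0]  # 0 * log(0) is treated as 0
        neg_entropy = np.sum(positive * np.log(positive))
        total += transport + neg_entropy / lam
    return total / len(plans)


# Toy 2x2 "images": the product coupling mu v^T is always feasible
D2 = pixel_sq_distances(2)
lam = 60 / np.median(D2)
mu = np.full(4, 0.25)
v = np.array([0.4, 0.1, 0.1, 0.4])
plan = np.outer(mu, v)
print(lam, regularized_objective([plan], D2, lam))
```

For the $2\times2$ grid the squared distances are four 0s, eight 1s and four 2s, so the median is 1 and $\lambda=60$. A larger $\lambda$ weakens the entropy term and moves the problem closer to the unregularized LP.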

Without the entropy term the problem is just the linear problem of computing a distribution $\mu$ minimizing the average squared Wasserstein distance to the $\upsilon^k$, as studied in our other notebook. Entropy regularization was suggested to us by Stefano Gualandi and appears, for example, in the paper by Cuturi and Doucet http://proceedings.mlr.press/v32/cuturi14.pdf, which also discusses the choice of $\lambda$ in more detail. More information about the LP approach to the Wasserstein metric can be found in Stefano Gualandi's blog post.

In this example we visualize the Wasserstein barycenter of $2$ handwritten '3' digits from the MNIST database; each image has $28\times 28$ pixels. Computations were carried out on an Intel(R) Xeon(R) CPU E5-2687W v4 @ 3.00GHz processor.

In [1]:
import struct
import numpy as np
import pandas as pd
%matplotlib inline
import matplotlib.pyplot as plt

#Define the number of images for the barycenter calculation
n=2
number = 3

#Read the images from the file
def read_idx(filename):
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)
    
data = read_idx('train-images.idx3-ubyte')
labels = read_idx('train-labels.idx1-ubyte')
#Select the images
digits = data[labels == number]
train = digits[:n]

plt.figure(figsize=(20,10))
for i in range(10):
    plt.subplot(2,5,i+1)
    plt.imshow(digits[np.random.randint(0,digits.shape[0])])
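`read_idx` relies on the IDX layout of the MNIST files: two zero bytes, a one-byte type code (0x08 for unsigned bytes), a one-byte number of dimensions, one big-endian 32-bit size per dimension, and then the raw data. If the MNIST files are not at hand, one way to check the parser is to write a tiny synthetic file in that layout and read it back. The helper `write_idx` and the file name `tiny.idx3-ubyte` are hypothetical, introduced only for this sketch.

```python
import os
import struct
import tempfile

import numpy as np


def read_idx(filename):
    # Same parser as above: header '>HBB' = zero word, type code, number of dims
    with open(filename, 'rb') as f:
        zero, data_type, dims = struct.unpack('>HBB', f.read(4))
        shape = tuple(struct.unpack('>I', f.read(4))[0] for d in range(dims))
        return np.frombuffer(f.read(), dtype=np.uint8).reshape(shape)


def write_idx(filename, array):
    """Write a uint8 array in IDX format (type code 0x08)."""
    array = np.asarray(array, dtype=np.uint8)
    with open(filename, 'wb') as f:
        f.write(struct.pack('>HBB', 0, 0x08, array.ndim))
        for size in array.shape:
            f.write(struct.pack('>I', size))  # one big-endian uint32 per dimension
        f.write(array.tobytes())  # C-order pixel data


images = np.arange(2 * 3 * 3, dtype=np.uint8).reshape(2, 3, 3)
path = os.path.join(tempfile.mkdtemp(), 'tiny.idx3-ubyte')
write_idx(path, images)
restored = read_idx(path)
print(restored.shape)  # (2, 3, 3)
```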

Regularized Barycenters using Mosek Fusion

In [2]:
from mosek.fusion import *
import time
import sys

class Wasserstein_Fusion:
    
    def __init__(self):
        self.time = 0.0
        self.M = Model('Wasserstein')
        self.result = None
    
    
    def single_pmf(self, data = None, img=False):
        
        ''' Takes an image or an array of images and extracts the probability mass function '''
        
        if not img:
            v=[]
            for image in data:
                arr = np.asarray(image).ravel(order='K')
                v.append(arr/np.sum(arr))
        else:
            v = np.asarray(data).ravel(order='K')
            v = v/np.sum(v)
        return v
    
    def ms_distance(self, m ,n, constant=False):
        
        ''' Squared Euclidean distance calculation between the pixels '''
        
        if constant:
            d = np.ones((m,m))
        else:
            d = np.empty((m,m))
            coor = []
            for i in range(n):
                for j in range(n):
                    coor.append(np.array([i,j]))
            for i in range(m):
                for j in range(m):
                    d[i][j] = np.linalg.norm(coor[i]-coor[j])**2
        return d
    
    def Wasserstein_Distance(self, bc ,data, img = False):
        
        ''' Calculation of the Wasserstein distance between a barycenter and an image by solving the minimization problem '''
    
        v = np.array(self.single_pmf(data, img))
        n = v.shape[0]
        d = self.ms_distance(n,data.shape[1])
        with Model('Wasserstein') as M:
            #Add variable
            pi = M.variable('pi',[n,n], Domain.greaterThan(0.0))
            
            #Add constraints
            M.constraint('c1' , Expr.sum(pi,0), Domain.equalsTo(v))
            M.constraint('c2' , Expr.sum(pi,1), Domain.equalsTo(bc))
            
            M.objective('Obj.' , ObjectiveSense.Minimize, Expr.dot(d, pi))
            
            M.solve()
            objective = M.primalObjValue()
            
        return objective
    
    def Wasserstein_BaryCenter(self,data):
        
        M = self.M
        start_time = time.time()
        k = data.shape[0]
        v = np.array(self.single_pmf(data))
        n = v.shape[1]
        d = self.ms_distance(n,data.shape[1])
        
        #Add variables   
        mu = M.variable('Mu', n, Domain.greaterThan(0.0))  
        pi = (M.variable('Pi', [k,n,n] , Domain.greaterThan(0.0)))
                
        #Add constraints    
        
        #Constraint (1)
        M.constraint('B', Expr.sub(Expr.sum(pi,1) , Var.repeat(mu,1,k).transpose()), Domain.equalsTo(0.0))
        #Constraint (2)
        M.constraint('C', Expr.sum(pi,2), Domain.equalsTo(v))
            
        M.objective('Obj' , ObjectiveSense.Minimize, Expr.sum(Expr.mul(Expr.mul(Expr.reshape(pi.asExpr(), k, n*n) , d.ravel()), 1/k)))
        
        M.setLogHandler(sys.stdout)
        M.solve()
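        # With the default selection, Fusion reports the basic solution (from basis identification)
        self.result = mu.level()
        # Switch to the interior-point solution: the objective and the returned mu.level() below use it
        M.selectedSolution(SolutionType.Interior)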
        self.objective = M.primalObjValue()
        self.time = time.time() - start_time
        
        return mu.level()
    
    def Wasserstein_regBaryCenter(self,data, _lambda = None, relgap = None):
        
        M = self.M
        start_time = time.time()
        k = data.shape[0]
        v = np.array(self.single_pmf(data))
        n = v.shape[1]
        d = self.ms_distance(n,data.shape[1])
        
        if not _lambda:
            _lambda = 60/np.median(d.ravel())
        
        
        #Add variables   
        mu = M.variable('Mu', n, Domain.greaterThan(0.0))  
        pi = (M.variable('Pi', [k,n,n] , Domain.greaterThan(0.0)))
        z = M.variable('z', [k,n*n]) #Artificial variable 
        
        #Add constraints
        # Exponential cone rows (1, pi, z) encode z <= -pi*log(pi); one block of n*n rows per image
        for i in range(1,k+1):
            M.constraint(Expr.hstack(Expr.constTerm(n*n, 1.0),
                                     Expr.reshape(pi.asExpr(), k*n*n).slice(0+(i-1)*n*n, n*n*i),
                                     Expr.reshape(z.asExpr(), k*n*n).slice(0+(i-1)*n*n, n*n*i)),
                                     Domain.inPExpCone())
    
        #Constraint (1)
        M.constraint('B', Expr.sub(Expr.sum(pi,1) , Var.repeat(mu,1,k).transpose()), Domain.equalsTo(0.0))
        #Constraint (2)
        M.constraint('C', Expr.sum(pi,2), Domain.equalsTo(v))
        
            
        M.objective('Obj' , ObjectiveSense.Minimize, Expr.sum(Expr.mul(Expr.add(Expr.mul(Expr.reshape(pi.asExpr(), k, n*n), d.ravel()),Expr.mul(Expr.sum(z,1),-1/_lambda)), 1/k)))
        
        #relgap is set in case of approximation
        if relgap:
            M.setSolverParam("intpntCoTolRelGap", relgap)
        
        M.setLogHandler(sys.stdout)
        M.solve()
        self.result = mu.level()
        self.objective = M.primalObjValue()
        
        self.time = time.time() - start_time
        
        return self.result
        
        
        
    def reset(self):
        self.M = Model('Wasserstein')
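The entropy term enters through `Domain.inPExpCone()`. The primal exponential cone is $\{x \in \mathbb{R}^3 : x_1 \ge x_2 e^{x_3/x_2},\ x_2 > 0\}$ together with its closure at $x_2 = 0$, so each stacked row $(1, \pi_{ij}^k, z_{ij}^k)$ encodes $1 \ge \pi e^{z/\pi}$, i.e. $z \le -\pi\log\pi$. Because the objective contains $-\frac1\lambda\sum z$, minimization pushes every $z$ up to $-\pi\log\pi$. The NumPy check below is only an illustration, not part of the model; the helper `in_primal_exp_cone` is hypothetical. It confirms that $z = -\pi\log\pi$ lies on the cone boundary while any larger $z$ is cut off.

```python
import numpy as np


def in_primal_exp_cone(x1, x2, x3, tol=1e-12):
    """Membership test for x1 >= x2 * exp(x3 / x2) with x2 > 0."""
    return x2 > 0 and x1 >= x2 * np.exp(x3 / x2) - tol


pi_vals = np.array([1e-3, 0.01, 0.3, 0.9])
z_bound = -pi_vals * np.log(pi_vals)  # the largest z the cone allows

for p, z in zip(pi_vals, z_bound):
    assert in_primal_exp_cone(1.0, p, z)             # on the boundary
    assert not in_primal_exp_cone(1.0, p, z + 1e-3)  # a slightly larger z violates the cone
print(z_bound)
```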
    
In [3]:
fusion_model = Wasserstein_Fusion()
f_bc = fusion_model.Wasserstein_regBaryCenter(train)
print('\nTime Spent to solve problem with Fusion: \n {0}'.format(fusion_model.time))
print('Time Spent in solver: \n {0}'.format(fusion_model.M.getSolverDoubleInfo("optimizerTime")))
print('Objective: \n {0}'.format(fusion_model.objective))
Problem
  Name                   : Wasserstein     
  Objective sense        : min             
  Type                   : CONIC (conic optimization problem)
  Constraints            : 3691072         
  Cones                  : 1229312         
  Scalar variables       : 6147345         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer started.
Presolve started.
Linear dependency checker started.
Linear dependency checker terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 1                 time                   : 0.00            
Lin. dep.  - tries                  : 1                 time                   : 0.79            
Lin. dep.  - number                 : 1               
Presolve terminated. Time: 4.88    
Problem
  Name                   : Wasserstein     
  Objective sense        : min             
  Type                   : CONIC (conic optimization problem)
  Constraints            : 3691072         
  Cones                  : 1229312         
  Scalar variables       : 6147345         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer  - threads                : 24              
Optimizer  - solved problem         : the primal      
Optimizer  - Constraints            : 1138
Optimizer  - Cones                  : 1229312
Optimizer  - Scalar variables       : 3687936           conic                  : 3687936         
Optimizer  - Semi-definite variables: 0                 scalarized             : 0               
Factor     - setup time             : 1.28              dense det. time        : 0.00            
Factor     - ML order time          : 0.01              GP order time          : 0.00            
Factor     - nonzeros before factor : 2.79e+05          after factor           : 3.41e+05        
Factor     - dense dim.             : 0                 flops                  : 1.17e+08        
ITE PFEAS    DFEAS    GFEAS    PRSTATUS   POBJ              DOBJ              MU       TIME  
0   6.3e+02  5.1e+02  2.4e+07  0.00e+00   2.213290906e+07   -1.586952925e+06  1.0e+00  10.10 
1   2.2e+02  1.8e+02  1.4e+07  -9.87e-01  1.997448375e+07   -3.251324601e+06  3.5e-01  12.82 
2   1.4e+02  1.1e+02  1.0e+07  -8.50e-01  1.706335861e+07   -4.132584878e+06  2.2e-01  15.44 
3   8.6e+01  6.9e+01  6.1e+06  -5.35e-01  1.255537920e+07   -4.350508351e+06  1.4e-01  17.98 
4   3.9e+01  3.1e+01  2.3e+06  -5.91e-02  6.634411697e+06   -3.290271335e+06  6.2e-02  20.60 
5   1.6e+01  1.3e+01  6.8e+05  4.56e-01   3.047731323e+06   -1.738035830e+06  2.5e-02  23.20 
6   9.2e+00  7.4e+00  3.2e+05  6.93e-01   1.871267186e+06   -1.107832915e+06  1.5e-02  25.98 
7   3.8e+00  3.0e+00  9.0e+04  7.82e-01   8.077963925e+05   -5.102524925e+05  5.9e-03  28.62 
8   1.0e+00  8.1e-01  1.4e+04  8.51e-01   2.027478294e+05   -1.783065556e+05  1.6e-03  31.46 
9   4.2e-01  3.4e-01  4.0e+03  8.87e-01   8.245159686e+04   -8.499836533e+04  6.6e-04  34.22 
10  1.2e-01  9.9e-02  6.9e+02  9.00e-01   2.254754497e+04   -2.962276989e+04  2.0e-04  36.82 
11  3.6e-02  2.9e-02  1.2e+02  9.02e-01   5.918291547e+03   -1.024285722e+04  5.7e-05  39.36 
12  1.1e-02  8.5e-03  2.0e+01  9.00e-01   1.600474776e+03   -3.421335906e+03  1.7e-05  41.97 
13  3.3e-03  2.6e-03  3.7e+00  9.03e-01   4.338461931e+02   -1.203269461e+03  5.2e-06  44.63 
14  9.2e-04  7.4e-04  5.9e-01  9.15e-01   9.098495781e+01   -3.977520685e+02  1.5e-06  47.24 
15  2.5e-04  2.0e-04  8.5e-02  9.22e-01   1.259274659e+00   -1.351784897e+02  3.9e-07  49.97 
16  7.1e-05  5.7e-05  1.4e-02  9.31e-01   -1.815269221e+01  -5.892007163e+01  1.1e-07  52.66 
17  1.6e-05  1.3e-05  1.6e-03  9.35e-01   -2.349890101e+01  -3.326483992e+01  2.6e-08  55.32 
18  3.2e-06  2.6e-06  1.5e-04  9.38e-01   -2.441819572e+01  -2.642396208e+01  5.1e-09  57.91 
19  7.5e-07  6.0e-07  1.7e-05  9.44e-01   -2.452957265e+01  -2.501505945e+01  1.2e-09  60.52 
20  1.8e-07  1.4e-07  2.0e-06  9.49e-01   -2.454254575e+01  -2.466072167e+01  2.8e-10  63.35 
21  3.6e-08  2.9e-08  1.9e-07  9.48e-01   -2.454079867e+01  -2.456580997e+01  5.7e-11  65.94 
22  9.3e-09  7.5e-09  2.6e-08  9.50e-01   -2.453852251e+01  -2.454516650e+01  1.5e-11  68.59 
23  2.3e-09  1.8e-09  3.3e-09  9.51e-01   -2.453774803e+01  -2.453943434e+01  3.6e-12  71.15 
24  7.2e-10  5.0e-10  4.7e-10  9.51e-01   -2.453754481e+01  -2.453801359e+01  9.8e-13  73.75 
25  7.1e-10  4.1e-10  3.6e-10  9.54e-01   -2.453752642e+01  -2.453791941e+01  8.1e-13  81.43 
26  3.5e-10  9.8e-11  4.3e-11  9.54e-01   -2.453745494e+01  -2.453755111e+01  1.9e-13  88.92 
27  7.1e-10  9.8e-11  4.3e-11  9.78e-01   -2.453745493e+01  -2.453755104e+01  1.9e-13  98.60 
28  7.4e-10  9.8e-11  4.3e-11  9.79e-01   -2.453745493e+01  -2.453755104e+01  1.9e-13  110.68
29  8.4e-10  9.8e-11  4.3e-11  9.79e-01   -2.453745492e+01  -2.453755103e+01  1.9e-13  121.76
30  1.1e-09  9.8e-11  4.3e-11  9.83e-01   -2.453745492e+01  -2.453755100e+01  1.9e-13  134.08
31  1.1e-09  9.8e-11  4.3e-11  9.79e-01   -2.453745492e+01  -2.453755099e+01  1.9e-13  145.08
32  1.1e-09  9.8e-11  4.3e-11  9.77e-01   -2.453745490e+01  -2.453755091e+01  1.9e-13  155.13
33  1.4e-09  9.8e-11  4.3e-11  9.78e-01   -2.453745488e+01  -2.453755083e+01  1.9e-13  164.86
34  1.4e-09  9.8e-11  4.3e-11  9.79e-01   -2.453745488e+01  -2.453755081e+01  1.9e-13  175.32
35  1.4e-09  9.8e-11  4.3e-11  9.77e-01   -2.453745485e+01  -2.453755066e+01  1.9e-13  184.89
36  1.4e-09  9.8e-11  4.3e-11  9.77e-01   -2.453745482e+01  -2.453755053e+01  1.9e-13  193.91
37  1.4e-09  9.8e-11  4.3e-11  9.80e-01   -2.453745481e+01  -2.453755051e+01  1.9e-13  204.86
38  1.4e-09  9.7e-11  4.3e-11  9.84e-01   -2.453745481e+01  -2.453755048e+01  1.9e-13  215.45
39  1.4e-09  9.7e-11  4.3e-11  9.79e-01   -2.453745480e+01  -2.453755047e+01  1.9e-13  226.12
40  1.4e-09  9.7e-11  4.3e-11  9.79e-01   -2.453745479e+01  -2.453755043e+01  1.9e-13  236.39
41  1.4e-09  9.7e-11  4.3e-11  9.80e-01   -2.453745478e+01  -2.453755039e+01  1.9e-13  246.32
42  1.4e-09  9.7e-11  4.3e-11  9.82e-01   -2.453745478e+01  -2.453755038e+01  1.9e-13  257.70
43  1.4e-09  9.7e-11  4.3e-11  9.77e-01   -2.453745477e+01  -2.453755033e+01  1.9e-13  267.63
44  1.4e-09  9.7e-11  4.3e-11  9.80e-01   -2.453745476e+01  -2.453755030e+01  1.9e-13  277.59
45  1.4e-09  9.7e-11  4.3e-11  9.79e-01   -2.453745476e+01  -2.453755027e+01  1.9e-13  288.02
46  1.4e-09  9.7e-11  4.3e-11  9.78e-01   -2.453745475e+01  -2.453755026e+01  1.9e-13  298.50
47  1.4e-09  9.7e-11  4.2e-11  9.78e-01   -2.453745472e+01  -2.453755009e+01  1.9e-13  307.32
48  1.4e-09  9.7e-11  4.2e-11  9.80e-01   -2.453745472e+01  -2.453755008e+01  1.9e-13  318.93
49  1.4e-09  9.7e-11  4.2e-11  9.79e-01   -2.453745472e+01  -2.453755008e+01  1.9e-13  329.79
50  1.4e-09  9.7e-11  4.2e-11  9.79e-01   -2.453745471e+01  -2.453755007e+01  1.9e-13  340.67
51  1.4e-09  9.7e-11  4.2e-11  9.78e-01   -2.453745471e+01  -2.453755007e+01  1.9e-13  351.40
52  1.4e-09  9.7e-11  4.2e-11  9.85e-01   -2.453745468e+01  -2.453754991e+01  1.9e-13  360.84
53  1.4e-09  9.7e-11  4.2e-11  9.86e-01   -2.453745468e+01  -2.453754991e+01  1.9e-13  372.95
54  7.9e-10  1.1e-11  1.7e-12  1.00e+00   -2.453743187e+01  -2.453744287e+01  2.1e-14  379.56
55  1.2e-09  1.1e-11  1.7e-12  1.00e+00   -2.453743185e+01  -2.453744281e+01  2.1e-14  390.36
56  1.2e-09  1.1e-11  1.6e-12  1.00e+00   -2.453743181e+01  -2.453744264e+01  2.1e-14  400.86
57  1.4e-09  1.1e-11  1.6e-12  1.00e+00   -2.453743175e+01  -2.453744243e+01  2.0e-14  410.98
58  1.4e-09  1.1e-11  1.6e-12  1.00e+00   -2.453743175e+01  -2.453744242e+01  2.0e-14  423.31
59  1.4e-09  1.1e-11  1.6e-12  1.00e+00   -2.453743175e+01  -2.453744241e+01  2.0e-14  436.25
60  1.4e-09  1.1e-11  1.6e-12  1.00e+00   -2.453743175e+01  -2.453744241e+01  2.0e-14  449.46
61  6.7e-10  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  456.59
62  6.3e-10  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  469.17
63  8.4e-10  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  481.70
64  1.0e-09  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  494.25
65  1.1e-09  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  506.36
66  1.3e-09  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  519.57
67  1.2e-09  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  531.70
68  1.2e-09  1.0e-12  4.6e-14  1.00e+00   -2.453742837e+01  -2.453742939e+01  1.8e-15  544.16
Optimizer terminated. Time: 558.63  


Interior-point solution summary
  Problem status  : PRIMAL_AND_DUAL_FEASIBLE
  Solution status : OPTIMAL
  Primal.  obj: -2.4537428373e+01   nrm: 1e+00    Viol.  con: 1e-07    var: 0e+00    cones: 2e-12  
  Dual.    obj: -2.4537429385e+01   nrm: 8e+02    Viol.  con: 0e+00    var: 2e-11    cones: 0e+00  

Time Spent to solve problem with Fusion: 
 592.5434217453003
Time Spent in solver: 
 558.629515171051
Objective: 
 -24.53742837253434
In [4]:
nonReg_model = Wasserstein_Fusion()
nonReg = nonReg_model.Wasserstein_BaryCenter(train)
print('\nTime Spent to solve non-regularized problem with Fusion: \n {0}'.format(nonReg_model.time))
print('Time Spent in solver: \n {0}'.format(nonReg_model.M.getSolverDoubleInfo("optimizerTime")))
print('The average Wasserstein distance between digits and the barycenter: \n {0}'.format(nonReg_model.objective))
Problem
  Name                   : Wasserstein     
  Objective sense        : min             
  Type                   : LO (linear optimization problem)
  Constraints            : 3136            
  Cones                  : 0               
  Scalar variables       : 1230097         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer started.
Presolve started.
Linear dependency checker started.
Linear dependency checker terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 1                 time                   : 0.00            
Lin. dep.  - tries                  : 1                 time                   : 0.03            
Lin. dep.  - number                 : 1               
Presolve terminated. Time: 0.96    
Problem
  Name                   : Wasserstein     
  Objective sense        : min             
  Type                   : LO (linear optimization problem)
  Constraints            : 3136            
  Cones                  : 0               
  Scalar variables       : 1230097         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer  - threads                : 24              
Optimizer  - solved problem         : the primal      
Optimizer  - Constraints            : 1138
Optimizer  - Cones                  : 0
Optimizer  - Scalar variables       : 278320            conic                  : 0               
Optimizer  - Semi-definite variables: 0                 scalarized             : 0               
Factor     - setup time             : 0.11              dense det. time        : 0.00            
Factor     - ML order time          : 0.01              GP order time          : 0.00            
Factor     - nonzeros before factor : 2.79e+05          after factor           : 3.41e+05        
Factor     - dense dim.             : 0                 flops                  : 1.15e+08        
ITE PFEAS    DFEAS    GFEAS    PRSTATUS   POBJ              DOBJ              MU       TIME  
0   1.8e+04  1.2e+04  5.7e+08  0.00e+00   2.533143200e+07   0.000000000e+00   5.1e+02  1.25  
1   1.8e+00  1.2e+00  5.8e+04  -1.00e+00  2.530057802e+07   -1.985024391e+04  5.2e-02  1.32  
2   4.0e-03  2.6e-03  1.3e+02  -6.05e-01  6.809717409e+04   -3.693719644e+04  1.1e-04  1.42  
3   1.1e-03  7.4e-04  3.7e+01  2.56e+02   1.601417700e+02   -1.716296687e+01  3.3e-05  1.54  
4   5.7e-04  3.7e-04  1.8e+01  1.42e+00   7.400571125e+01   -6.338392658e+00  1.6e-05  1.66  
5   7.5e-05  4.9e-05  2.4e+00  1.20e+00   9.664295656e+00   -2.705164314e-01  2.2e-06  1.77  
6   2.1e-05  1.4e-05  6.8e-01  1.06e+00   3.422966212e+00   6.973656981e-01   6.1e-07  1.87  
7   4.1e-06  1.8e-06  1.3e-01  1.02e+00   1.573043800e+00   1.067352406e+00   1.1e-07  1.95  
8   8.6e-07  3.7e-07  2.6e-02  1.00e+00   1.225114421e+00   1.119539692e+00   2.4e-08  2.06  
9   1.5e-07  6.3e-08  4.5e-03  1.00e+00   1.147377919e+00   1.129274252e+00   4.0e-09  2.15  
10  6.3e-09  2.9e-09  1.9e-04  1.00e+00   1.131288555e+00   1.130509098e+00   1.7e-10  2.21  
11  7.4e-13  3.7e-13  2.3e-08  1.00e+00   1.130528318e+00   1.130528226e+00   2.0e-14  2.27  
12  2.2e-16  6.3e-14  1.5e-12  1.00e+00   1.130528228e+00   1.130528228e+00   2.0e-18  2.31  
Basis identification started.
Basis identification terminated. Time: 0.07
Optimizer terminated. Time: 2.72    


Interior-point solution summary
  Problem status  : PRIMAL_AND_DUAL_FEASIBLE
  Solution status : OPTIMAL
  Primal.  obj: 1.1305282283e+00    nrm: 1e+00    Viol.  con: 9e-16    var: 0e+00  
  Dual.    obj: 1.1305282283e+00    nrm: 8e+02    Viol.  con: 0e+00    var: 2e-13  

Basic solution summary
  Problem status  : PRIMAL_AND_DUAL_FEASIBLE
  Solution status : OPTIMAL
  Primal.  obj: 1.1305282283e+00    nrm: 1e+00    Viol.  con: 2e-17    var: 0e+00  
  Dual.    obj: 1.1305282283e+00    nrm: 8e+02    Viol.  con: 0e+00    var: 1e-11  

Time Spent to solve non-regularized problem with Fusion: 
 8.922340393066406
Time Spent in solver: 
 2.7243459224700928
The average Wasserstein distance between digits and the barycenter: 
 1.1305282283356233
In [5]:
plt.figure(figsize=(15,8))
plt.subplot(1,3,1)
plt.imshow(np.reshape(f_bc,(28,28)))
plt.title('Regularized Barycenter')
plt.subplot(1,3,2)
plt.imshow(np.reshape(nonReg, (28,28)))
plt.title('Non-Regularized Interior-Point Solution')
plt.subplot(1,3,3)
plt.imshow(np.reshape(nonReg_model.result, (28,28)))
plt.title('Non-Regularized Basic Solution')
plt.show()

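$\quad$ The interior-point solution differs from the basic solution. This is not a general property but a feature of this instance: the non-regularized LP has infinitely many optimal solutions. In that case the interior-point method returns a point in the interior of the optimal face, i.e. a convex combination of optimal extreme points, whereas basis identification returns a single extreme point. When the optimal solution is unique, the two coincide.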
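A two-variable toy LP shows the effect. For minimize $x_1 + x_2$ subject to $x_1 + x_2 = 1$, $x \ge 0$, every feasible point is optimal. A basic (vertex) solution has at most one nonzero, since there is one equality constraint, while a convex combination of the two optimal vertices, the kind of point an interior-point method tends to return, has the same objective value but two nonzeros. The sketch only checks these facts numerically; it does not call a solver.

```python
import numpy as np

c = np.array([1.0, 1.0])
A = np.array([[1.0, 1.0]])
b = np.array([1.0])

vertices = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]  # basic optimal solutions
midpoint = 0.5 * vertices[0] + 0.5 * vertices[1]          # interior of the optimal face

for x in vertices + [midpoint]:
    assert np.allclose(A @ x, b) and np.all(x >= 0)  # feasible
    print(x, c @ x, np.count_nonzero(x))             # objective 1.0 in every case
```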

$\quad $ Solving the regularized problem takes a long time even for just 2 images. However, the iteration log shows that the solver spends most of that time on very small improvements: after iteration 24 (about 74 s) the primal objective changes by less than $10^{-5}$ in relative terms, yet the remaining iterations take roughly 485 of the 559 seconds of solver time. Since the regularization term is an artificial addition to the problem anyway, it is sensible to test whether a slightly less accurate solution changes the values and images significantly. This is done with the "intpntCoTolRelGap" parameter, which controls the relative gap termination tolerance of the conic optimizer. In the next run the parameter is increased from 1.0e-7 to 1.0e-3, so that MOSEK terminates earlier with somewhat lower accuracy.
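The figures behind this observation can be read off the log of the optimal run: at iteration 24 the primal objective is $-24.53754481$, against a final value of $-24.5374283730$, and the optimizer clock reads 73.75 s at iteration 24 and 558.63 s at termination. A few lines of Python turn these logged values into a relative change and a share of solver time.

```python
pobj_it24 = -2.453754481e+01    # primal objective at iteration 24 (from the log)
pobj_final = -2.4537428373e+01  # primal objective of the optimal solution
t_it24, t_final = 73.75, 558.63  # optimizer time in seconds

rel_change = abs(pobj_it24 - pobj_final) / abs(pobj_final)
late_share = (t_final - t_it24) / t_final
print(f"relative change after iteration 24: {rel_change:.1e}")
print(f"share of solver time spent after iteration 24: {late_share:.0%}")
```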

Obtaining approximate values

In [6]:
fusion_model2 = Wasserstein_Fusion()
f_bc2 = fusion_model2.Wasserstein_regBaryCenter(train, relgap = "1.0e-3")
print('\nTime Spent to solve problem with Fusion: \n {0}'.format(fusion_model2.time))
print('Time Spent in solver: \n {0}'.format(fusion_model2.M.getSolverDoubleInfo("optimizerTime")))
print('Objective: \n {0}'.format(fusion_model2.objective))
Problem
  Name                   : Wasserstein     
  Objective sense        : min             
  Type                   : CONIC (conic optimization problem)
  Constraints            : 3691072         
  Cones                  : 1229312         
  Scalar variables       : 6147345         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer started.
Presolve started.
Linear dependency checker started.
Linear dependency checker terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 1                 time                   : 0.00            
Lin. dep.  - tries                  : 1                 time                   : 0.75            
Lin. dep.  - number                 : 1               
Presolve terminated. Time: 4.73    
Problem
  Name                   : Wasserstein     
  Objective sense        : min             
  Type                   : CONIC (conic optimization problem)
  Constraints            : 3691072         
  Cones                  : 1229312         
  Scalar variables       : 6147345         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer  - threads                : 24              
Optimizer  - solved problem         : the primal      
Optimizer  - Constraints            : 1138
Optimizer  - Cones                  : 1229312
Optimizer  - Scalar variables       : 3687936           conic                  : 3687936         
Optimizer  - Semi-definite variables: 0                 scalarized             : 0               
Factor     - setup time             : 1.34              dense det. time        : 0.00            
Factor     - ML order time          : 0.01              GP order time          : 0.00            
Factor     - nonzeros before factor : 2.79e+05          after factor           : 3.41e+05        
Factor     - dense dim.             : 0                 flops                  : 1.17e+08        
ITE PFEAS    DFEAS    GFEAS    PRSTATUS   POBJ              DOBJ              MU       TIME  
0   6.3e+02  5.1e+02  2.4e+07  0.00e+00   2.213290906e+07   -1.586952925e+06  1.0e+00  9.79  
1   2.2e+02  1.8e+02  1.4e+07  -9.87e-01  1.997448375e+07   -3.251324601e+06  3.5e-01  12.58 
2   1.4e+02  1.1e+02  1.0e+07  -8.50e-01  1.706335861e+07   -4.132584878e+06  2.2e-01  15.35 
3   8.6e+01  6.9e+01  6.1e+06  -5.35e-01  1.255537920e+07   -4.350508351e+06  1.4e-01  18.28 
4   3.9e+01  3.1e+01  2.3e+06  -5.91e-02  6.634411697e+06   -3.290271335e+06  6.2e-02  20.98 
5   1.6e+01  1.3e+01  6.8e+05  4.56e-01   3.047731323e+06   -1.738035830e+06  2.5e-02  23.66 
6   9.2e+00  7.4e+00  3.2e+05  6.93e-01   1.871267186e+06   -1.107832915e+06  1.5e-02  26.15 
7   3.8e+00  3.0e+00  9.0e+04  7.82e-01   8.077963925e+05   -5.102524925e+05  5.9e-03  28.85 
8   1.0e+00  8.1e-01  1.4e+04  8.51e-01   2.027478294e+05   -1.783065556e+05  1.6e-03  31.64 
9   4.2e-01  3.4e-01  4.0e+03  8.87e-01   8.245159686e+04   -8.499836533e+04  6.6e-04  34.18 
10  1.2e-01  9.9e-02  6.9e+02  9.00e-01   2.254754497e+04   -2.962276989e+04  2.0e-04  36.94 
11  3.6e-02  2.9e-02  1.2e+02  9.02e-01   5.918291547e+03   -1.024285722e+04  5.7e-05  39.48 
12  1.1e-02  8.5e-03  2.0e+01  9.00e-01   1.600474776e+03   -3.421335906e+03  1.7e-05  42.19 
13  3.3e-03  2.6e-03  3.7e+00  9.03e-01   4.338461931e+02   -1.203269461e+03  5.2e-06  44.90 
14  9.2e-04  7.4e-04  5.9e-01  9.15e-01   9.098495781e+01   -3.977520685e+02  1.5e-06  47.40 
15  2.5e-04  2.0e-04  8.5e-02  9.22e-01   1.259274659e+00   -1.351784897e+02  3.9e-07  49.96 
16  7.1e-05  5.7e-05  1.4e-02  9.31e-01   -1.815269221e+01  -5.892007163e+01  1.1e-07  52.59 
17  1.6e-05  1.3e-05  1.6e-03  9.35e-01   -2.349890101e+01  -3.326483992e+01  2.6e-08  55.25 
18  3.2e-06  2.6e-06  1.5e-04  9.38e-01   -2.441819572e+01  -2.642396208e+01  5.1e-09  57.85 
19  7.5e-07  6.0e-07  1.7e-05  9.44e-01   -2.452957265e+01  -2.501505945e+01  1.2e-09  60.41 
20  1.8e-07  1.4e-07  2.0e-06  9.49e-01   -2.454254575e+01  -2.466072167e+01  2.8e-10  62.93 
21  3.6e-08  2.9e-08  1.9e-07  9.48e-01   -2.454079867e+01  -2.456580997e+01  5.7e-11  65.49 
22  9.3e-09  7.5e-09  2.6e-08  9.50e-01   -2.453852251e+01  -2.454516650e+01  1.5e-11  68.06 
23  2.3e-09  1.8e-09  3.3e-09  9.51e-01   -2.453774803e+01  -2.453943434e+01  3.6e-12  70.63 
24  7.2e-10  5.0e-10  4.7e-10  9.51e-01   -2.453754481e+01  -2.453801359e+01  9.8e-13  73.16 
Optimizer terminated. Time: 75.21   


Interior-point solution summary
  Problem status  : PRIMAL_AND_DUAL_FEASIBLE
  Solution status : OPTIMAL
  Primal.  obj: -2.4537544810e+01   nrm: 1e+00    Viol.  con: 1e-08    var: 0e+00    cones: 6e-12  
  Dual.    obj: -2.4538013590e+01   nrm: 8e+02    Viol.  con: 0e+00    var: 1e-08    cones: 0e+00  

Time Spent to solve problem with Fusion: 
 109.92306613922119
Time Spent in solver: 
 75.20934391021729
Objective: 
 -24.53754480969593
In [7]:
plt.figure(figsize=(10,10))
plt.subplot(1,2,1)
fus_bc = np.reshape(f_bc,(28,28))
plt.imshow(fus_bc)
plt.title('Optimal Barycenter')
plt.subplot(1,2,2)
fus_bc2 = np.reshape(f_bc2,(28,28))
plt.imshow(fus_bc2)
plt.title('Approximate Barycenter')
plt.show()
In [8]:
error = pd.Series(f_bc - f_bc2)
plt.plot(error)
plt.title('Error of Approximation')
plt.xlabel("Pixels")
plt.ylabel('Error Values')
pd.DataFrame(error.describe(), columns=['Stats']).transpose()
Out[8]:
      count  mean           std            min            25%            50%            75%            max
Stats 784.0  -2.565828e-09  8.372986e-09   -2.593707e-08  -8.399538e-09  -3.616124e-09  -4.976524e-10  2.277483e-08
In [9]:
plt.figure(figsize=(12,7))

total_t = [fusion_model.time, fusion_model2.time]
solver_t = [fusion_model.M.getSolverDoubleInfo("optimizerTime"), fusion_model2.M.getSolverDoubleInfo("optimizerTime")]
#Total time plot
plt.subplot(1,2,1)
plt.bar(['Optimal', 'Approximation'], height= total_t,
        width=0.4, color=(0.3, 0.6, 0.2, 0.5))
plt.ylabel("Total Time (s)")
plt.title("Comparison of Total Time")

#Solver time plot
plt.subplot(1,2,2)
plt.bar(['Optimal', 'Approximation'], height=solver_t,
        width=0.4, color=(0.5, 0.6, 0.9, 0.8))
plt.ylabel("Solver Time (s)")
plt.title("Comparison of Solver Time")
plt.show()

$\quad$ The two barycenter images obtained from the optimal and the approximate solutions look practically identical. The statistical description of the error between them, shown above together with a plot, supports this: the mean error is about $-2.6\cdot10^{-9}$ and even the extreme values stay below $3\cdot10^{-8}$ in absolute value. Meanwhile the total time drops from about 593 s to about 110 s, so for this problem the time saved by the approximation clearly outweighs the error it introduces.
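To put these figures in proportion, the printed timings and the error statistics can be combined as follows. The only assumption is that the average pixel mass $1/784$ of a $28\times28$ probability image is a reasonable yardstick for the size of the error.

```python
# Timings and error extremes copied from the outputs of the two runs above
total_opt, total_approx = 592.5434217453003, 109.92306613922119
solver_opt, solver_approx = 558.629515171051, 75.20934391021729
max_abs_error = max(abs(-2.593707e-08), abs(2.277483e-08))

print(f"total-time speedup:  {total_opt / total_approx:.1f}x")
print(f"solver-time speedup: {solver_opt / solver_approx:.1f}x")
# Average mass per pixel of a 28x28 probability image is 1/784
print(f"largest error relative to average pixel mass: {max_abs_error * 784:.1e}")
```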

Modeling the same problem with CVXPY

In [10]:
import cvxpy as cp
import time
class Wasserstein_CVXPY:
    
    def __init__(self):
        self.time = 0.0
        self.result = None
        self.prob = None

    def single_pmf(self, data = None, img=False):
        
        ''' Takes an image or an array of images and extracts the probability mass function '''
        
        if not img:
            v=[]
            for image in data:
                arr = np.asarray(image).ravel(order='K')
                v.append(arr/np.sum(arr))
        else:
            v = np.asarray(data).ravel(order='K')
            v = v/np.sum(v)
        return v
    
    def ms_distance(self, m ,n, constant=False):
        
        ''' Squared Euclidean distance calculation between the pixels '''
        
        if constant:
            d = np.ones((m,m))
        else:
            d = np.empty((m,m))
            coor = []
            for i in range(n):
                for j in range(n):
                    coor.append(np.array([i,j]))
            for i in range(m):
                for j in range(m):
                    d[i][j] = np.linalg.norm(coor[i]-coor[j])**2
        return d
    
    def Wasserstein_Distance(self, bc ,data, img = False):
        
        ''' Calculation of the Wasserstein distance between a barycenter and an image by solving 
            the minimization problem '''
    
        v = np.array(self.single_pmf(data, img))
        n = v.shape[0]
        d = self.ms_distance(n,data.shape[1])
        
        pi = cp.Variable((n,n), nonneg=True)
        obj = cp.Minimize((np.ones(n).T @ cp.multiply(d,pi) @ np.ones(n)))
        
        Cons=[]
        Cons.append((np.ones(n) @ pi).T == bc)
        Cons.append((pi @ np.ones(n)) == v)
        
        prob = cp.Problem(obj, constraints= Cons)
        
        return prob.solve(solver=cp.MOSEK, verbose = True)
        
    def Wasserstein_BaryCenter(self,data):
        
        ''' Computes the Wasserstein barycenter of the given images by solving the linear minimization problem '''
        
        start_time = time.time()
        k = data.shape[0]
        v = np.array(self.single_pmf(data))
        n = v.shape[1]
        d = self.ms_distance(n,data.shape[1])
        
        #Add variables
        pi= []
        t= []
        mu = cp.Variable(n, nonneg = True)
        for i in range(k):
            pi.append(cp.Variable((n,n), nonneg = True))
            t.append(cp.Variable(nonneg = True))
            
        obj = cp.Minimize(sum(t)/k)
        
        #Add constraints
        Cons=[]
        for i in range(k):
            Cons.append( t[i] >= np.ones(n).T @  cp.multiply(d,pi[i]) @ np.ones(n) ) # t[i] bounds the transport cost of plan i
            Cons.append( (np.ones(n) @ pi[i]).T == mu)                               # column sums of plan i equal the barycenter mu
            Cons.append( (pi[i] @ np.ones(n)) == v[i])                               # row sums of plan i equal the pmf of image i
            
        self.prob = cp.Problem(obj, constraints= Cons)
        self.result = self.prob.solve(solver=cp.MOSEK,verbose = True)
        self.time = time.time() - start_time
        
        return mu.value
    
    def Wasserstein_regBaryCenter(self,data, _lambda = None, relgap=1.0e-7):
        
        ''' Computes the Wasserstein barycenter of the given 
            images by solving an entropy-regularized minimization problem '''
        
        start_time = time.time()
        k = data.shape[0]
        v = np.array(self.single_pmf(data))
        n = v.shape[1]
        d = self.ms_distance(n,data.shape[1])
        
        if _lambda is None:
            # default regularization weight: 60 divided by the median squared pixel distance
            _lambda = 60/np.median(d.ravel())
        
        #Add variables
        pi= []
        t= []
        mu = cp.Variable(n, nonneg = True)
        for i in range(k):
            pi.append(cp.Variable((n,n), nonneg = True))
            t.append(cp.Variable(nonneg = True))
            
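        # cp.entr(x) = -x*log(x), so subtracting it adds (1/lambda)*sum(pi*log(pi)) to the objective
        obj = cp.Minimize((sum(t) - (1/_lambda)*sum(cp.sum(cp.entr(pi[i])) for i in range(k)))/k)
        
        #Add constraints
        Cons=[]
        for i in range(k):
            Cons.append( t[i] >= (np.ones(n).T @  cp.multiply(d,pi[i]) @ np.ones(n)))# t[i] bounds the transport cost of plan i
            Cons.append( (np.ones(n) @ pi[i]).T == mu)                               # column sums of plan i equal the barycenter mu
            Cons.append( (pi[i] @ np.ones(n)) == v[i])                               # row sums of plan i equal the pmf of image i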
            
        self.prob = cp.Problem(obj, constraints= Cons)
        self.result = self.prob.solve(solver=cp.MOSEK,
                                      verbose = True, 
                                      mosek_params = {"MSK_DPAR_INTPNT_CO_TOL_REL_GAP" : relgap })
        self.time = time.time() - start_time
        
        return mu.value
    
    def reset(self):
        self.prob = None
        self.result = None
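`ms_distance` fills the m×m cost matrix with a Python double loop. For 28×28 images this means 784² = 614,656 calls to `np.linalg.norm`. As a hedged alternative, and not the notebook's own code, the same matrix can be built with NumPy broadcasting. The helper names `ms_distance_vectorized` and `ms_distance_loop` are hypothetical. The sketch assumes the pixels are enumerated in row-major order, exactly as the nested `coor` loop in `ms_distance` does, and it checks the result against that loop on a small image.

```python
import numpy as np


def ms_distance_vectorized(n):
    """Hypothetical helper: squared Euclidean distances between all pixels
    of an n-by-n image, pixels enumerated in row-major order."""
    # Pixel index p = i*n + j  ->  coordinates (i, j), same order as the loop in ms_distance
    rows, cols = np.divmod(np.arange(n * n), n)
    coor = np.stack([rows, cols], axis=1).astype(float)
    # Pairwise coordinate differences via broadcasting: shape (n*n, n*n, 2)
    diff = coor[:, None, :] - coor[None, :, :]
    # Sum of squared coordinate differences = squared Euclidean distance
    return np.sum(diff ** 2, axis=-1)


def ms_distance_loop(n):
    """Reference version: the same double loop as ms_distance (with m = n*n)."""
    m = n * n
    coor = [np.array([i, j]) for i in range(n) for j in range(n)]
    d = np.empty((m, m))
    for i in range(m):
        for j in range(m):
            d[i][j] = np.linalg.norm(coor[i] - coor[j]) ** 2
    return d


d_fast = ms_distance_vectorized(4)
print(np.allclose(d_fast, ms_distance_loop(4)))  # True
```

For 28×28 images, the intermediate `diff` array holds 784 × 784 × 2 floats, about 10 MB, so memory is not a concern at this size. If SciPy is available, `scipy.spatial.distance.cdist(coor, coor, 'sqeuclidean')` computes the same matrix.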
In [11]:
cvxpy_model = Wasserstein_CVXPY()
cvxpy_result = cvxpy_model.Wasserstein_regBaryCenter(train, relgap = 1.0e-3)
print('\nTime Spent to solve problem with CVXPY: \n {0}'.format(cvxpy_model.time))
print('Time Spent in solver: \n {0}'.format(cvxpy_model.prob.solver_stats.solve_time))
print('The average Wasserstein distance between digits and the barycenter: \n {0}'.format(cvxpy_model.result))
/home/yhk/anaconda3/lib/python3.6/site-packages/ipykernel_launcher.py:116: DeprecationWarning: Calling np.sum(generator) is deprecated, and in the future will give a different result. Use np.sum(np.fromiter(generator)) or the python sum builtin instead.
Problem
  Name                   :                 
  Objective sense        : min             
  Type                   : CONIC (conic optimization problem)
  Constraints            : 4921172         
  Cones                  : 1229312         
  Scalar variables       : 6147346         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer started.
Presolve started.
Linear dependency checker started.
Linear dependency checker terminated.
Eliminator started.
Freed constraints in eliminator : 0
Eliminator terminated.
Eliminator - tries                  : 1                 time                   : 0.00            
Lin. dep.  - tries                  : 1                 time                   : 0.94            
Lin. dep.  - number                 : 1               
Presolve terminated. Time: 5.52    
Problem
  Name                   :                 
  Objective sense        : min             
  Type                   : CONIC (conic optimization problem)
  Constraints            : 4921172         
  Cones                  : 1229312         
  Scalar variables       : 6147346         
  Matrix variables       : 0               
  Integer variables      : 0               

Optimizer  - threads                : 24              
Optimizer  - solved problem         : the primal      
Optimizer  - Constraints            : 1138
Optimizer  - Cones                  : 1229312
Optimizer  - Scalar variables       : 3687936           conic                  : 3687936         
Optimizer  - Semi-definite variables: 0                 scalarized             : 0               
Factor     - setup time             : 1.32              dense det. time        : 0.00            
Factor     - ML order time          : 0.01              GP order time          : 0.00            
Factor     - nonzeros before factor : 2.79e+05          after factor           : 3.41e+05        
Factor     - dense dim.             : 0                 flops                  : 1.17e+08        
ITE PFEAS    DFEAS    GFEAS    PRSTATUS   POBJ              DOBJ              MU       TIME  
0   6.3e+02  5.1e+02  2.4e+07  0.00e+00   2.213290906e+07   -1.586952925e+06  1.0e+00  10.80 
1   2.2e+02  1.8e+02  1.4e+07  -9.87e-01  1.997448687e+07   -3.251322359e+06  3.5e-01  13.53 
2   1.4e+02  1.1e+02  1.0e+07  -8.50e-01  1.706336316e+07   -4.132584664e+06  2.2e-01  16.00 
3   8.6e+01  6.9e+01  6.1e+06  -5.35e-01  1.255539057e+07   -4.350508595e+06  1.4e-01  18.47 
4   3.9e+01  3.1e+01  2.3e+06  -5.91e-02  6.634396136e+06   -3.290269560e+06  6.2e-02  21.06 
5   1.6e+01  1.3e+01  6.8e+05  4.56e-01   3.047707328e+06   -1.738025645e+06  2.5e-02  23.59 
6   9.2e+00  7.4e+00  3.2e+05  6.93e-01   1.871307617e+06   -1.107856288e+06  1.5e-02  26.26 
7   3.8e+00  3.0e+00  9.0e+04  7.82e-01   8.078374063e+05   -5.102762310e+05  5.9e-03  28.94 
8   1.0e+00  8.1e-01  1.4e+04  8.51e-01   2.027566219e+05   -1.783122870e+05  1.6e-03  31.50 
9   4.2e-01  3.4e-01  4.0e+03  8.87e-01   8.245390849e+04   -8.500024053e+04  6.6e-04  34.09 
10  1.2e-01  9.9e-02  6.9e+02  9.00e-01   2.254812883e+04   -2.962344930e+04  2.0e-04  36.68 
11  3.6e-02  2.9e-02  1.2e+02  9.02e-01   5.918677286e+03   -1.024336468e+04  5.7e-05  39.24 
12  1.1e-02  8.5e-03  2.0e+01  9.00e-01   1.600641695e+03   -3.421611681e+03  1.7e-05  41.83 
13  3.3e-03  2.6e-03  3.7e+00  9.03e-01   4.338561459e+02   -1.203297297e+03  5.2e-06  44.30 
14  9.2e-04  7.4e-04  5.9e-01  9.15e-01   9.099180087e+01   -3.977689860e+02  1.5e-06  46.90 
15  2.5e-04  2.0e-04  8.5e-02  9.22e-01   1.262938718e+00   -1.351900013e+02  3.9e-07  49.58 
16  7.1e-05  5.7e-05  1.4e-02  9.31e-01   -1.815195892e+01  -5.892312187e+01  1.1e-07  52.15 
17  1.6e-05  1.3e-05  1.6e-03  9.35e-01   -2.349893678e+01  -3.326489098e+01  2.6e-08  54.74 
18  3.2e-06  2.6e-06  1.5e-04  9.38e-01   -2.441820572e+01  -2.642393794e+01  5.1e-09  57.33 
19  7.5e-07  6.0e-07  1.7e-05  9.44e-01   -2.452958761e+01  -2.501494965e+01  1.2e-09  59.85 
20  1.8e-07  1.4e-07  2.0e-06  9.49e-01   -2.454255453e+01  -2.466070613e+01  2.8e-10  62.42 
21  3.6e-08  2.9e-08  1.9e-07  9.48e-01   -2.454079667e+01  -2.456580565e+01  5.7e-11  65.00 
22  9.3e-09  7.5e-09  2.6e-08  9.50e-01   -2.453852207e+01  -2.454516536e+01  1.5e-11  67.53 
23  2.3e-09  1.8e-09  3.3e-09  9.51e-01   -2.453774797e+01  -2.453943362e+01  3.6e-12  70.11 
24  6.9e-10  5.0e-10  4.7e-10  9.51e-01   -2.453754486e+01  -2.453801407e+01  9.8e-13  72.59 
Optimizer terminated. Time: 74.99   


Interior-point solution summary
  Problem status  : PRIMAL_AND_DUAL_FEASIBLE
  Solution status : OPTIMAL
  Primal.  obj: -2.4537544857e+01   nrm: 4e+00    Viol.  con: 1e-08    var: 0e+00    cones: 6e-12  
  Dual.    obj: -2.4538014075e+01   nrm: 8e+02    Viol.  con: 7e-15    var: 1e-08    cones: 0e+00  

Time Spent to solve problem with CVXPY: 
 102.05331134796143
Time Spent in solver: 
 74.98805212974548
The average Wasserstein distance between digits and the barycenter: 
 -24.537544856968044
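The value printed above as "the average Wasserstein distance" is negative, which a transport cost can never be. It is actually the optimal value of the whole regularized objective in `Wasserstein_regBaryCenter`: the transport cost plus (1/λ)·Σ π log π, averaged over the k images. Every entry of a transport plan lies in [0, 1], so the entropy term is non-positive and can dominate the sum.

The NumPy sketch below illustrates this decomposition on a tiny one-dimensional example. To stay solver-free, it uses Sinkhorn's matrix-scaling iterations, a standard method for entropy-regularized transport between two fixed marginals, instead of the MOSEK barycenter model. The grid, the marginals `mu` and `nu`, and the value of `lam` are made-up illustration data.

```python
import numpy as np


def sinkhorn(mu, nu, cost, lam, iters=2000):
    """Entropy-regularized transport plan between marginals mu and nu.

    Approximately minimizes <cost, pi> + (1/lam) * sum(pi * log(pi))
    subject to row sums mu and column sums nu, by alternately rescaling
    the rows and columns of the kernel K = exp(-lam * cost).
    """
    K = np.exp(-lam * cost)
    u = np.ones_like(mu)
    for _ in range(iters):
        v = nu / (K.T @ u)   # rescale columns to match nu
        u = mu / (K @ v)     # rescale rows to match mu
    return u[:, None] * K * v[None, :]


# Made-up illustration data: two strictly positive pmfs on a 1-D grid of 5 points
x = np.arange(5.0)
cost = (x[:, None] - x[None, :]) ** 2     # squared distances, as in the notebook
mu = np.array([0.4, 0.3, 0.1, 0.1, 0.1])
nu = np.array([0.1, 0.1, 0.2, 0.3, 0.3])
lam = 0.1                                  # illustrative regularization weight

pi = sinkhorn(mu, nu, cost, lam)
transport = np.sum(cost * pi)                   # transport cost, always >= 0
entropy_term = np.sum(pi * np.log(pi)) / lam    # <= 0 because 0 < pi <= 1
objective = transport + entropy_term

print(f"transport cost: {transport:.3f}")
print(f"entropy term:   {entropy_term:.3f}")
print(f"objective:      {objective:.3f}")
```

With these numbers the objective is negative even though the transport cost is positive. This is the same effect seen in the CVXPY run above. To report a genuine average transport cost from the barycenter model, one would evaluate Σ D²π on the optimal plans (the values of `pi[i]` inside `Wasserstein_regBaryCenter`), which the class does not currently return.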
In [12]:
plt.imshow(np.reshape(cvxpy_result.squeeze(), (28,28)))
plt.title('Barycenter')
plt.show()
In [13]:
plt.figure(figsize=(12,6))

total_t50 = [fusion_model2.time, cvxpy_model.time]
solver_t50 = [fusion_model2.M.getSolverDoubleInfo("optimizerTime"), cvxpy_model.prob.solver_stats.solve_time]

#Total time plot
plt.subplot(1,2,1)
plt.bar(['Fusion', 'CVXPY'], height= total_t50,
        width=0.4, color=(0.3, 0.6, 0.2, 0.5))
plt.ylabel("Total Time (s)")
plt.title("Comparison of Total Time")

#Solver time plot
plt.subplot(1,2,2)
plt.bar(['Fusion','CVXPY'], height=solver_t50,
        width=0.4, color=(0.5, 0.6, 0.9, 0.8))
plt.ylabel("Solver Time (s)")
plt.title("Comparison of Solver Time")
plt.show()
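The solver times compared here largely reflect the size of the conic problem that reaches MOSEK. In the CVXPY model, each of the k transport plans is an n×n matrix, with n = 784 pixels for a 28×28 image. `cp.entr` turns every entry into one exponential cone, so the model contains k·n² exponential cones.

The quick check below assumes that `train` holds k = 2 images. This is inferred from the MOSEK log rather than stated in the notebook. Under that assumption, the count matches the 1229312 cones and the 3687936 conic scalar variables reported in the optimizer summary.

```python
# Size of the entropy-regularized CVXPY model passed to MOSEK
n_side = 28                  # image side length in pixels
n = n_side * n_side          # pixels per image = length of each pmf
k = 2                        # ASSUMPTION: number of images in `train`, inferred from the MOSEK log

plan_entries = k * n * n     # entries of the k transport plans pi[0], ..., pi[k-1]
exp_cones = plan_entries     # cp.entr introduces one exponential cone per entry
conic_vars = 3 * exp_cones   # every exponential cone has three scalar members

print(n, exp_cones, conic_vars)  # 784 1229312 3687936
```

Because the cone count grows with the square of the number of pixels, doubling the image resolution to 56×56 would multiply the number of cones by 16.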

This work is licensed under a Creative Commons Attribution 4.0 International License. The MOSEK logo and name are trademarks of Mosek ApS. The code is provided as-is. Compatibility with future releases of MOSEK or the Fusion API is not guaranteed. For more information contact our support.