using LinearAlgebra
using Plots
pyplot()
using JuMP
using MosekTools

# Number of images used for the barycenter calculation.
n = 20

"""
    read_idx(filename)

Read the IDX-style file `filename` and return its payload as a `UInt8` array.

The first four bytes are the header; the fourth byte is taken as the number of
dimensions. Each dimension size is then read as a 4-byte integer and converted from
network byte order. The payload is always read as `UInt8`; the type code stored in
the third header byte is not inspected. The dimension order is reversed for the
returned array, so the first dimension listed in the file becomes the last array
dimension.
"""
function read_idx(filename)
    # The do-block form guarantees the file is closed, even if reading fails.
    return open(filename, "r") do io
        header = zeros(UInt8, 4)
        readbytes!(io, header, 4)

        data_shape = Int32[]
        for _ in 1:header[4]
            raw = zeros(UInt8, 4)
            readbytes!(io, raw, 4)
            push!(data_shape, ntoh(reinterpret(Int32, raw)[1]))
        end

        idx_data = zeros(UInt8, prod(data_shape))
        read!(io, idx_data)
        return reshape(idx_data, Tuple(reverse(data_shape)))
    end
end

data = read_idx("train-images-idx3-ubyte")
labels = read_idx("train-labels-idx1-ubyte")

# Select the images labelled 1 and keep the first `n` of them.
mask = labels .== 1
train_ones = data[:, :, mask]
train = train_ones[:, :, 1:n]

# Preview ten randomly chosen images of the selected digit.
x = collect(1:28)
y = reverse(x)
f, ax = PyPlot.plt.subplots(2, 5, sharey=true, sharex=true, figsize=(10, 5))
PyPlot.plt.xticks([5, 10, 15, 20, 25])
for i in 1:10
    rand_pick = rand(1:size(train_ones, 3))
    ax[i].pcolormesh(x, y, transpose(train_ones[:, :, rand_pick]))
end

# Flatten image `k` of `data` and scale it so that its entries sum to one.
function _image_pmf(data, k)
    pixels = vec(data[:, :, k])
    total = sum(pixels)
    if iszero(total)
        throw(ArgumentError("image $k has zero total mass and cannot be normalized"))
    end
    return pixels ./ total
end

"""
    single_pmf(data)

Turn a stack of images into probability mass functions.

Each slice `data[:, :, k]` is flattened column-major and divided by its sum. A 2-D
`data` is treated as a single image.

# Returns
- `v`: the normalized images, one per column (a plain vector when there is only one
  image).
- `N`: the number of pixels per image, i.e. `size(v, 1)`.

# Throws
- `ArgumentError` if an image sums to zero.
"""
function single_pmf(data)
    # `size(data, 3)` is 1 for a 2-D array, whereas `size(data)[3]` would throw.
    n_images = size(data, 3)
    first_pmf = _image_pmf(data, 1)
    n_pixels = length(first_pmf)
    if n_images == 1
        return first_pmf, n_pixels
    end

    # Fill a preallocated matrix instead of growing one with repeated `hcat`.
    pmf = similar(first_pmf, n_pixels, n_images)
    pmf[:, 1] = first_pmf
    for k in 2:n_images
        pmf[:, k] = _image_pmf(data, k)
    end
    return pmf, n_pixels
end

"""
    ms_distance(m, n)

Return the `m`×`m` matrix of squared Euclidean distances between the first `m` cells
of an `n`×`n` grid.

Cell `p` has grid coordinates `divrem(p - 1, n)`, which matches the order produced by
flattening an `n`×`n` image with `vec`.

# Throws
- `ArgumentError` if `m > n^2`, since the grid would not have enough cells.
"""
function ms_distance(m, n)
    m <= n^2 || throw(ArgumentError("m = $m exceeds the $n x $n grid size"))
    d = Matrix{Float64}(undef, m, m)
    for j in 1:m
        col_j, row_j = divrem(j - 1, n)
        for i in 1:m
            col_i, row_i = divrem(i - 1, n)
            # Integer arithmetic keeps the squared distance exact (no sqrt round trip).
            d[i, j] = (col_i - col_j)^2 + (row_i - row_j)^2
        end
    end
    return d
end

"""
    wasserstein_barycenter(data)

Build the linear program for the non-regularized Wasserstein barycenter of the images
in `data` and return `(M, M_mu)`: the JuMP model (in direct mode on a Mosek optimizer)
and the barycenter variables.

For each image `k` a transport plan `M_pi[:, :, k]` is constrained to have the image's
pmf as one marginal and the shared `M_mu` as the other. The objective is the average
transport cost under the squared Euclidean pixel distance.
"""
function wasserstein_barycenter(data)
    M = direct_model(Mosek.Optimizer())

    v, N = single_pmf(data)
    # Take the image count from the pmf so it always agrees with the columns of `v`.
    K = size(v, 2)
    d = ms_distance(N, size(data, 2))

    # Index sets.
    M_i = 1:N
    M_j = 1:N
    M_k = 1:K

    # Variables: one transport plan per image, plus the barycenter itself.
    @variable(M, M_pi[i = M_i, j = M_j, k = M_k] >= 0.0)
    @variable(M, M_mu[i = M_i] >= 0.0)

    # Marginal constraints.
    @constraint(M, c3_expr[k = M_k, j = M_j], sum(M_pi[i, j, k] for i in M_i) == v[j, k])
    @constraint(M, c2_expr[k = M_k, i = M_i], sum(M_pi[i, j, k] for j in M_j) == M_mu[i])

    # Objective: average transport cost over the K images.
    @objective(M, Min, sum(d[i, j] * M_pi[i, j, k] for i in M_i, j in M_j, k in M_k) / K)

    return M, M_mu
end

"""
    run_model(data)

Build and solve the barycenter model for `data`, print the termination status and the
objective value, and return the values of the barycenter variables.
"""
function run_model(data)
    @time begin
        M, M_mu = wasserstein_barycenter(data)
        optimize!(M)
    end
    println("Solution status = ", termination_status(M))
    println("Primal objective value = ", objective_value(M))
    mu_level = value.(M_mu)
    return mu_level
end

"""
    show_barycenter(bary_center)

Reshape the flat vector `bary_center` into a square image and display it.

The side length is inferred from the number of entries (28 for a 784-entry vector).

# Throws
- `ArgumentError` if the number of entries is not a perfect square.
"""
function show_barycenter(bary_center)
    n_pixels = length(bary_center)
    side = isqrt(n_pixels)
    if side^2 != n_pixels
        throw(ArgumentError("expected a square image, got $n_pixels entries"))
    end
    bary_center = reshape(bary_center, (side, side))
    x = collect(1:side)
    y = reverse(x)
    PyPlot.plt.pcolormesh(x, y, transpose(bary_center))
    PyPlot.plt.title("Non-regularized Wasserstein Barycenter")
    PyPlot.plt.show()
end

bary_center = run_model(train)
println("******")
show_barycenter(bary_center)