using PyPlot
using NtToolBox

# ---------------------------------------------------------------------------
# Discrete optimal transport solved as a linear program.
#
# A coupling Gamma (n0 x n1) is flattened column-major as vec(Gamma). The two
# marginal constraints  sum(Gamma, 2) == p0  and  sum(Gamma, 1)' == p1  are
# stacked into a single linear system  Sigma(n0, n1) * vec(Gamma) == [p0; p1].
# ---------------------------------------------------------------------------

# Cols(n0, n1): sparse n1 x (n0*n1) 0/1 matrix; row j has ones exactly at the
# flattened positions of column j of Gamma, so it computes the column sums.
Cols = (n0, n1) -> sparse(vec(repeat(1:n1, outer=(1, n0))'),
                          vec(reshape(collect(1:n0*n1), n0, n1)),
                          ones(Int, n0*n1))

# Rows(n0, n1): sparse n0 x (n0*n1) 0/1 matrix; row i has ones exactly at the
# flattened positions of row i of Gamma, so it computes the row sums.
Rows = (n0, n1) -> sparse(vec(repeat(1:n0, outer=(1, n1))'),
                          vec(reshape(collect(1:n0*n1), n0, n1)'),
                          ones(Int, n0*n1))

# Full (n0 + n1) x (n0*n1) constraint matrix: row marginals first, then column marginals.
Sigma = (n0, n1) -> [Rows(n0, n1); Cols(n0, n1)]

# Parameters forwarded unchanged to NtToolBox's perform_linprog.
maxit = 1e4
tol = 1e-9

# otransp(C, p0, p1): hand the cost C and the constraints Sigma * x == [p0; p1]
# to perform_linprog and reshape the solution to a length(p0) x length(p1) coupling.
# NOTE(review): presumably perform_linprog minimises <C, x> subject to x >= 0 and the
# equality constraints -- confirm against NtToolBox.
otransp = (C, p0, p1) -> reshape(perform_linprog(Array(Sigma(length(p0), length(p1))),
                                                 [vec(p0); vec(p1)], C, maxit, tol),
                                 length(p0), length(p1))

# Pairwise squared Euclidean distances: entry (i, j) is |X[:, i] - Y[:, j]|^2.
sqdist(X, Y) = repeat(sum(X.^2, 1)', outer=(1, size(Y, 2))) +
               repeat(sum(Y.^2, 1), outer=(size(X, 2), 1)) - 2*X'*Y

# Set the axis limits to the bounding box of the columns of X (2 x n) plus a margin.
function setlims(X, margin=.1)
    xlim(minimum(X[1, :]) - margin, maximum(X[1, :]) + margin)
    ylim(minimum(X[2, :]) - margin, maximum(X[2, :]) + margin)
    return nothing
end

# gauss(q, a, c): q points (2 x q) drawn from an isotropic Gaussian with std a,
# centred at c (given as a 1 x 2 row, hence the transpose).
gauss = (q, a, c) -> a*randn(2, q) + repeat(c', outer=(1, q))

# Target cloud of exactly n points in three clusters (about 1/2, 1/4, 1/4 of n).
# The last cluster takes the remainder so the column count always equals n,
# even when n is not divisible by 4.
threeclusters(n) = hcat(gauss(div(n, 2), .5, [0 1.6]),
                        gauss(div(n, 4), .3, [-1 -1]),
                        gauss(n - div(n, 2) - div(n, 4), .3, [1 -1]))

# Rescale a non-negative array so that it sums to 1.
# NOTE(review): this global shadows Base.normalize in Julia 0.6 -- consider renaming.
normalize = a -> a/sum(a)

# Scatter helper: black-edged markers whose area scales with ms.
myplot = (x, y, ms, col) -> scatter(x, y, s=ms*20, edgecolors="k", c=col, linewidths=2)

# ---------------------------------------------------------------------------
# Experiment 1: point clouds with random (non-uniform) weights.
# ---------------------------------------------------------------------------
n0 = 60
n1 = 80

X0 = randn(2, n0)*.3
X1 = threeclusters(n1)

p0 = normalize(rand(n0, 1))
p1 = normalize(rand(n1, 1))

# Show both weighted clouds; marker size reflects each point's mass.
figure(figsize=(10, 7))
axis("off")
for i in eachindex(p0)
    myplot(X0[1, i], X0[2, i], p0[i]*length(p0)*10, "b")
end
for i in eachindex(p1)
    myplot(X1[1, i], X1[2, i], p1[i]*length(p1)*10, "r")
end
# Limits are set once, after drawing, and cover both clouds.
setlims([X0 X1])

C = sqdist(X0, X1)
Gamma = otransp(C, p0, p1)

# A vertex (basic) solution of this LP has at most n0 + n1 - 1 non-zeros.
# NOTE(review): `> 0` also counts any tiny positive solver round-off.
println("Number of non-zero: $(count(x -> x > 0, Gamma)) (n0 + n1-1 = $(n0 + n1-1))")
println("Constraints deviation (should be 0): $(norm(sum(Gamma,2)-vec(p0))), $(norm(sum(Gamma, 1)'-vec(p1)))")

# Displacement interpolation: each transported mass Gammaij[k] travels along
# the segment from X0[:, I[k]] to X1[:, J[k]] as t goes from 0 to 1.
I, J, Gammaij = findnz(Gamma)
length(Gammaij)   # number of transported pairs (shown when run as a notebook cell)

figure(figsize=(15, 10))
tlist = collect(linspace(0, 1, 6))
for (i, t) in enumerate(tlist)
    Xt = (1-t)*X0[:, I] + t*X1[:, J]
    subplot(2, 3, i)
    axis("off")
    for j in eachindex(Gammaij)
        # Colour fades from blue (t = 0) to red (t = 1).
        myplot(Xt[1, j], Xt[2, j], Gammaij[j]*length(Gammaij)*6, [t, 0, 1-t])
    end
    title("t = $t")
    # Xt is a convex combination of X0 and X1 points, so these limits contain it.
    setlims([X0 X1])
end

# ---------------------------------------------------------------------------
# Experiment 2: equal-size clouds with uniform weights. A vertex solution of
# the LP is then a (scaled) permutation matrix, i.e. a one-to-one matching.
# ---------------------------------------------------------------------------
n0 = 40
n1 = n0

X0 = randn(2, n0)*.3
X1 = threeclusters(n1)

p0 = ones(n0)/n0
p1 = ones(n1)/n1

C = sqdist(X0, X1)

figure(figsize=(10, 7))
axis("off")
myplot(X0[1, :], X0[2, :], 10, "b")
myplot(X1[1, :], X1[2, :], 10, "r")
setlims([X0 X1])

Gamma = otransp(C, p0, p1)

figure(figsize=(5, 5))
imageplot(Gamma)

# Draw a segment for every non-zero entry of the coupling.
I, J = findn(Gamma)
figure(figsize=(10, 7))
axis("off")
for k in eachindex(I)
    plot([X0[1, I[k]]; X1[1, J[k]]], [X0[2, I[k]]; X1[2, J[k]]], "k", lw=2)
end
myplot(X0[1, :], X0[2, :], 10, "b")
myplot(X1[1, :], X1[2, :], 10, "r")
setlims([X0 X1])