# load all relevant modules
import numpy as np
import matplotlib.pylab as pl
import ot
n = 100  # number of bins

# Source and target histograms: 1D Gaussians discretized on n bins
# (m = mean, s = standard deviation, both in bin units).
a = ot.datasets.get_1D_gauss(n, m=20, s=20)
b = ot.datasets.get_1D_gauss(n, m=60, s=60)

# bin positions
x = np.arange(n, dtype=np.float64)

# Loss matrix: pairwise cost between bin positions (ot.dist defaults to the
# squared Euclidean metric). Dividing by the maximum rescales the costs to
# [0, 1], so the regularization strength used below is relative to the
# largest cost.
M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
M /= M.max()

# plot the source and target distributions
pl.figure(1)
pl.plot(x, a, 'b', label='Source distribution')
pl.plot(x, b, 'r', label='Target distribution')
pl.legend()

# plot both distributions alongside the loss matrix
pl.figure(2)
ot.plot.plot1D_mat(a, b, M, 'Cost matrix M')

# exact (unregularized) optimal transport plan
G0 = ot.emd(a, b, M)
pl.figure(3)
ot.plot.plot1D_mat(a, b, G0, 'OT matrix G0')

# entropic regularization weight for Sinkhorn
# NOTE(review): with M scaled to [0, 1], exp(-M / lambd) reaches exp(-1000),
# which underflows to 0 in float64 for distant bins; increase lambd if the
# Sinkhorn solver reports numerical problems.
lambd = 1e-3
Gs = ot.sinkhorn(a, b, M, lambd)
pl.figure(4)
ot.plot.plot1D_mat(a, b, Gs, 'OT matrix Sinkhorn')

# Show all figures once they have all been created. Calling show() right
# after figure 1 (as before) blocks on non-interactive backends, and figures
# 2-4 were created afterwards and never displayed.
pl.show()