using OptimalTransport
using Distances
using Distributions
using StatsPlots
using LinearAlgebra
using Random

Random.seed!(1234);

# --- Continuous source μ, discrete target ν -----------------------------------

μ = Normal(0, 1)
N = 10
ν = Poisson(N);

# Quadratic cost. `sqeuclidean` from `Distances.jl` would work just as well.
c(x, y) = abs2(x - y)

# For a continuous univariate source, `ot_plan` returns a transport map `T`
# rather than a coupling matrix. `T` is applied point-wise, hence `T.(xs)` below.
T = ot_plan(c, μ, ν);

p1 = plot(μ; label="μ")
p1 = plot!(ν; marker=:circle, label="ν")
xs = -2:0.1:2
p2 = plot(xs, T.(xs); label="Monge map", color=:green, legend=:topleft)
plot(p1, p2)

ot_cost(c, μ, ν)
wasserstein(μ, ν; p=2)

# --- Two discrete measures: the optimal plan is a coupling matrix γ -----------

M = 15
μ = DiscreteNonParametric(1.5rand(M), fill(1 / M, M))
N = 10
ν = DiscreteNonParametric(1.5rand(N) .+ 2, fill(1 / N, N))

# γ[i, j] is the mass moved from the i-th support point of μ to the j-th support
# point of ν (rows/columns follow the order of `support(μ)` / `support(ν)`).
γ = ot_plan(sqeuclidean, μ, ν);

"""
    curve(x1, x2, y1, y2)

Return a one-argument function `f(x) = -a*x^2 + b*x + k` with `a = min(y1, y2)`,
whose graph passes through `(x1, y1)` and `(x2, y2)`. The parabola opens downward
when `a > 0`. It is used to draw an arc from a source atom to a target atom.

Throws an `ArgumentError` if `x1 == x2`, because the coefficients are then not
determined (the original formula would divide by zero).
"""
function curve(x1, x2, y1, y2)
    x1 == x2 && throw(ArgumentError("curve endpoints must differ, got x1 = x2 = $x1"))
    a = min(y1, y2)
    # Solve f(x1) = y1 and f(x2) = y2 for the linear and constant coefficients.
    b = (y1 - y2 + a * (x1^2 - x2^2)) / (x1 - x2)
    # Named `k` (not `c`) so it does not shadow the global cost function `c`.
    k = y1 + a * x1^2 - b * x1
    f(x) = -a * x^2 + b * x + k
    return f
end

p = plot(μ; marker=:circle, label="μ")
p = plot!(ν; marker=:circle, label="ν", ylims=(0, 0.2))

xμ, wμ = support(μ), probs(μ)
xν, wν = support(ν), probs(ν)
for i in axes(γ, 1), j in axes(γ, 2)
    γ[i, j] > 0 || continue
    transport = curve(xμ[i], xν[j], wμ[i], wν[j])
    x = range(xμ[i], xν[j]; length=100)
    # `plot!` mutates the current plot, so there is no need to reassign the
    # global `p` here; doing so in a top-level loop triggers a soft-scope warning.
    plot!(x, transport.(x); color=:green, label=nothing, alpha=0.5)
end
p

ot_cost(sqeuclidean, μ, ν)