# JKO (Jordan–Kinderlehrer–Otto) time stepping with entropic optimal transport on a 1D
# grid: Fokker–Planck (FPE, entropy energy) and porous medium (PME, m = 2) flows in a
# double-well potential, solved once in the primal (via `sinkhorn2`) and once in the
# semi-dual formulation.

using OptimalTransport
using Distances
using LogExpFunctions
using Optim
using Plots
using StatsBase
using ReverseDiff
using LinearAlgebra
using Logging

# Uniform grid on [-1, 1] and the squared-Euclidean cost matrix between grid points.
# `support'` is a 1×64 matrix whose columns are the points, hence `dims=2`.
support = range(-1, 1; length=64)
C = pairwise(SqEuclidean(), support'; dims=2);

"""
    E(ρ; m=1)

Internal energy of a discrete density `ρ`:

- `m == 1`: `∑ ρ log ρ - ∑ ρ` (entropy, used for the Fokker–Planck flow);
- `m > 1`:  `∑ (ρ^m - m ρ) / (m - 1)` (used for the porous medium flow).

Throws an `ArgumentError` for `m < 1`, for which no energy is defined here.
"""
function E(ρ; m=1)
    if m == 1
        return sum(xlogx.(ρ)) - sum(ρ)
    elseif m > 1
        return dot(ρ, @. (ρ^(m - 1) - m) / (m - 1))
    end
    # previously fell through and returned `nothing`, which would poison the objective
    throw(ArgumentError("E is only defined for m ≥ 1, got m = $m"))
end;

# Double-well scalar potential with minima at x = ±0.5, sampled on the grid.
ψ(x) = 10 * (x - 0.5)^2 * (x + 0.5)^2;
plot(support, ψ.(support); color="black", label="Scalar potential")
Ψ = ψ.(support);

# JKO time step τ, entropic regularization ε, and the Gibbs kernel K = exp(-C/ε)
# used by the semi-dual formulation.
τ = 0.05
ε = 0.01
K = @. exp(-C / ε);

# Initial condition: normalized indicator of (-0.25, 0.25] built from a Heaviside step.
H(x) = x > 0
ρ0 = @. H(support + 0.25) - H(support - 0.25)
ρ0 = ρ0 / sum(ρ0)
plot(support, ρ0; label="Initial condition ρ0", color="blue")

"""
    jko(stepper, ρ0, N, args...; verbose=false)

Run `N - 1` JKO steps starting from `ρ0`, where `stepper(ρ_prev, args...)` returns the
density at the next step. Returns a `length(ρ0) × N` matrix whose column `i` is the
density after `i - 1` steps (column 1 is `ρ0`). With `verbose=true`, the step index is
logged with `@info`.
"""
function jko(stepper, ρ0, N, args...; verbose=false)
    ρ = similar(ρ0, size(ρ0, 1), N)
    ρ[:, 1] = ρ0
    for i in 2:N
        verbose && @info i
        ρ[:, i] = stepper(ρ[:, i - 1], args...)
    end
    return ρ
end

N = 10
colors = range(colorant"red"; stop=colorant"blue", length=N)

# ---------------------------------------------------------------------------------------
# Primal formulation: minimize OT_ε(ρ, ρ0) + τ F(ρ) over the simplex.
# Plot titles state F up to terms linear in ρ; ρ = softmax(u) has unit mass, so those
# terms are constant.
# ---------------------------------------------------------------------------------------

function G_fpe(ρ, ρ0, τ, ε, C)
    return sinkhorn2(ρ, ρ0, C, ε; regularization=true, maxiter=250) +
           τ * (dot(Ψ, ρ) + E(ρ))
end;

"""
    step_primal(ρ0, τ, ε, C, G)

One primal JKO step: minimize `G(softmax(u), ρ0, τ, ε, C)` over `u` with L-BFGS
(forward-mode AD, at most 50 iterations) and return `softmax` of the minimizer.
Named `step_primal` rather than `step` to avoid shadowing `Base.step`.
"""
function step_primal(ρ0, τ, ε, C, G)
    # softmax keeps the iterate on the probability simplex without explicit constraints
    obj = u -> G(softmax(u), ρ0, τ, ε, C)
    # only print error messages emitted during the optimization
    opt = with_logger(SimpleLogger(stderr, Logging.Error)) do
        optimize(
            obj,
            ones(size(ρ0)),
            LBFGS(),
            Optim.Options(; iterations=50, g_tol=1e-6);
            autodiff=:forward,
        )
    end
    return softmax(Optim.minimizer(opt))
end

ρ = jko(step_primal, ρ0, N, τ, ε, C, G_fpe; verbose=true)
plot(
    support,
    ρ;
    title=raw"$F(\rho) = \langle \psi, \rho \rangle + \langle \rho, \log(\rho) \rangle$",
    palette=colors,
    legend=nothing,
)

function G_pme(ρ, ρ0, τ, ε, C)
    return sinkhorn2(ρ, ρ0, C, ε; regularization=true, maxiter=250) +
           τ * (dot(Ψ, ρ) + E(ρ; m=2))
end;

ρ = jko(step_primal, ρ0, N, τ, ε, C, G_pme)
plot(
    support,
    ρ;
    title=raw"$F(\rho) = \langle \psi, \rho \rangle + \langle \rho, \rho - 1\rangle$",
    palette=colors,
    legend=nothing,
)

# ---------------------------------------------------------------------------------------
# Semi-dual formulation: optimize over the dual potential u and recover the primal
# density afterwards.
# ---------------------------------------------------------------------------------------

# Convex conjugates of E: for m = 1, (ρ log ρ - ρ)* (u) = exp(u);
# for m = 2, (ρ^2 - 2ρ)* (u) = (u/2 + 1)^2 (nonnegativity of ρ not enforced).
E_dual(u, m::Val{1}) = sum(exp.(u))
function E_dual(u, m::Val{2})
    return dot(u / 2 .+ 1, u / 2 .+ 1)
end;

function G_dual_fpe(u, ρ0, τ, ε, K)
    return OptimalTransport.Dual.ot_entropic_semidual(ρ0, u, ε, K) +
           τ * E_dual(-u / τ - Ψ, Val(1))
end;

"""
    step_dual(ρ0, τ, ε, K, G)

One semi-dual JKO step: minimize `G(u, ρ0, τ, ε, K)` over the dual variable `u` with
L-BFGS (reverse-mode gradient via ReverseDiff, at most 250 iterations), then map the
minimizer back to a density with `getprimal_ot_entropic_semidual`.
Named `step_dual` rather than `step` to avoid shadowing `Base.step` and clobbering
`step_primal`.
"""
function step_dual(ρ0, τ, ε, K, G)
    obj = u -> G(u, ρ0, τ, ε, K)
    opt = optimize(
        obj,
        (∇, u) -> ReverseDiff.gradient!(∇, obj, u),
        zeros(size(ρ0)),
        LBFGS(),
        Optim.Options(; iterations=250, g_tol=1e-6),
    )
    return OptimalTransport.Dual.getprimal_ot_entropic_semidual(
        ρ0, Optim.minimizer(opt), ε, K
    )
end;

ρ = jko(step_dual, ρ0, N, τ, ε, K, G_dual_fpe)
plot(
    support,
    ρ;
    title=raw"$F(\rho) = \langle \psi, \rho \rangle + \langle \rho, \log(\rho) \rangle$",
    palette=colors,
    legend=nothing,
)

function G_dual_pme(u, ρ0, τ, ε, K)
    return OptimalTransport.Dual.ot_entropic_semidual(ρ0, u, ε, K) +
           τ * E_dual(-u / τ - Ψ, Val(2))
end

ρ = jko(step_dual, ρ0, N, τ, ε, K, G_dual_pme; verbose=true)
plot(
    support,
    ρ;
    title=raw"$F(\rho) = \langle \psi, \rho \rangle + \langle \rho, \rho - 1\rangle$",
    palette=colors,
    legend=nothing,
)