# Symbolic (SymPy via PyCall) derivation of mean-field variational Bayes for a
# Normal model with unknown mean μ and precision τ under the prior defined below:
#   y | μ, τ ~ Normal(mean μ, variance 1/τ)
#   μ | τ    ~ Normal(mean μ₀, variance 1/(λ₀τ))
#   τ        ~ Gamma(shape a₀, scale 1/b₀)
using Base.Meta: parse  # brings `Meta.parse` into scope under the name `parse`
using SpecialFunctions
# Evenly spaced range helper. NOTE(review): not referenced anywhere in the visible file.
linspace(a,b,L) = range(a, stop=b, length=L)
using PyPlot
using PyCall
using SymPy
# https://docs.sympy.org/latest/modules/printing.html#sympy.printing.julia.julia_code
# SymPy's Julia code printer, fetched with PyCall's `obj[:attr]` attribute syntax.
const julia_code = sympy[:julia_code]
# Report the Julia / Python / SymPy environment the notebook was run with.
@show versioninfo()
println()
@show PyCall.pyprogramname
@show PyCall.pyversion
@show PyCall.conda
@show PyCall.libpython
@show sympy[:__version__];
# Symbols: a single observation y, a positive dummy variable t, the mean μ and
# its prior mean μ₀, three observations Y1..Y3 (real) and T1..T3 (positive
# stand-ins used later), and the positive precision τ with hyperparameters.
y = symbols("y", real=true)
t = symbols("t", positive=true)
μ, μ₀ = symbols("μ μ₀", real=true)
Y = symbols("Y1:4", real=true)
T = symbols("T1:4", positive=true)
τ, λ₀, a₀, b₀ = symbols("τ λ₀ a₀ b₀", positive=true)
σ, σ² = symbols("σ σ²", positive=true)
θ, α = symbols("θ α", positive=true);
# Normal density at t with mean `mu` and variance `s2` (PI is SymPy's π).
pdf_Normal(mu, s2, t) = exp(-(t-mu)^2/(2*s2))/√(2*PI*s2)
@show pdf_Normal(μ, σ², y)
# Gamma density at t with shape `alpha` and scale `theta`.
pdf_Gamma(alpha, theta, t) = exp(-t/theta)*t^(alpha-1)/(gamma(alpha)*theta^alpha)
@show pdf_Gamma(α, θ, y)
# Likelihood of one observation, conditional prior of μ, and prior of τ
# (scale 1/b₀, i.e. rate b₀).
p_y = pdf_Normal(μ, 1/τ, y)
@show p_y
p_μ = pdf_Normal(μ₀, 1/(λ₀*τ), μ)
@show p_μ
p_τ = pdf_Gamma(a₀, 1/b₀, τ)
@show p_τ
# Normalisation sanity checks: both integrals should come out as 1.
I1 = integrate(p_μ, (μ, -oo, oo))
@show I1
I2 = integrate(p_τ, (τ, 0, oo))
@show I2
I2 = simplify(I2)
# Single-observation marginal over μ, its negative log ("free energy"), the
# special case μ₀ = 0, y = t, and the full evidence after also integrating
# τ against its Gamma prior.
Z1_μ = simplify(integrate(p_y*p_μ, (μ, -oo, oo)))
F1_μ = simplify(-log(Z1_μ))
F1_0_μ = F1_μ(μ₀=>0, y=>t)
F1_0_μ = expand(F1_0_μ)
Z1_0_μ = simplify(exp(-F1_0_μ))
Z1_0 = integrate(Z1_0_μ*p_τ, (τ, 0, oo))
# Log joint density for the three observations Y1..Y3 together with μ and τ.
log_p = sum(log(p_y(y=>x)) for x in Y) + log(p_μ) + log(p_τ)
log_p = simplify(log_p)
@show log_p
# Differentiating and re-integrating in one variable drops every term that does
# not depend on that variable (additive constants are irrelevant for the
# variational updates).
log_p_for_μ = integrate(diff(log_p, μ), μ)
@show log_p_for_μ
expand(log_p_for_μ)
collect(expand(log_p_for_μ), μ)
log_p_for_τ = integrate(diff(log_p, τ), τ)
@show log_p_for_τ
expand(log_p_for_τ)
log_p_for_τ = collect(expand(log_p_for_τ), τ)
# First update of q(μ): average the μ-dependent log joint over the prior p(τ),
# then normalise over μ.
log_q1_μ = integrate(log_p_for_μ * p_τ, (τ, 0, oo))
@show log_q1_μ
log_q1_μ = collect(expand(log_q1_μ), μ)
@show log_q1_μ
q1_μ = exp(log_q1_μ)
z1_μ = simplify(integrate(q1_μ, (μ, -oo, oo)))
@show z1_μ
q1_μ = 1/z1_μ * exp(log_q1_μ)
@show q1_μ
# First update of log q(τ): average the τ-dependent log joint over p_μ.
# NOTE(review): p_μ depends on τ, so this is not an expectation under a
# τ-independent q(μ) — confirm this is the intended initialisation.
log_q1_τ = integrate(log_p_for_τ * p_μ, (μ, -oo, oo))
@show log_q1_τ
# Drop the τ-independent terms again.
log_q1_τ = integrate(diff(log_q1_τ, τ), τ)
@show log_q1_τ
log_q1_τ = collect(log_q1_τ, τ)
@show log_q1_τ
# q(τ) ∝ exp(log q(τ)); logcombine merges the log terms into powers of τ.
q1_τ = logcombine(exp(log_q1_τ))
@show q1_τ
# Direct attempt at the normaliser over τ ∈ (0, ∞).
z1_τ = integrate(q1_τ, (τ, 0, oo))
@show z1_τ
q1_τ = 1/z1_τ * q1_τ
@show q1_τ
# Same unnormalised density, specialised to μ₀ = 0 and with the observations
# replaced by the positive symbols T1..T3 (presumably so that SymPy can
# evaluate the integral in closed form — TODO confirm).
q2_τ = logcombine(exp(log_q1_τ))
@show q2_τ
q2_0_τ = subs(q2_τ, (μ₀, 0), zip(Y,T)...)
z2_0_τ = integrate(q2_0_τ, (τ, 0, oo))
@show z2_0_τ
z2_0_τ = simplify(z2_0_τ)
@show z2_0_τ
q2_0_τ = 1/z2_0_τ * q2_0_τ
@show q2_0_τ
# Rebuild the unnormalised q(τ) and normalise it with a Gamma integral instead.
q1_τ = logcombine(exp(log_q1_τ))
display(q1_τ)
# coef: coefficient of the term linear in τ in log q(τ).
coef = coeff(collect(log_q1_τ, τ), τ)
display(coef)
# Sign check on coef: its stationary point in Y1 is μ₀, and with every Yi = μ₀
# it equals -b₀; the code below relies on -coef > 0.
sol = solve(diff(coef, Y[1]), Y[1])[1]
display(sol) #-> μ₀
replacements = [(var, sol) for var in Y]
display(subs(coef, replacements...)) #-> -b₀
# ∫₀^∞ τ^(a₀+1) exp(-ξτ) dτ with ξ > 0, then ξ := -coef.
ξ = symbols("ξ", positive=true)
z1_τ = simplify(integrate(τ^(a₀+1)*exp(-ξ*τ), (τ, 0, oo)))
z1_τ = subs(z1_τ, (ξ, -coef))
@show z1_τ
q1_τ = 1/z1_τ * q1_τ
q1_τ
# Numerical setup: three data points and the hyperparameters
# a₀ = b₀ = λ₀ = 1, μ₀ = 0.
data = [1.1, 1.0, 1.3]
replacements = [(a₀, 1.0), (b₀, 1.0), (μ₀, 0.0), (λ₀, 1.0), zip(Y,data)...]
log_p_for_μ_subs = subs(log_p_for_μ, replacements...)
log_p_for_τ_subs = subs(log_p_for_τ, replacements...)
# ---- Numerical mean-field iteration ---------------------------------------
# Inspect the μ- and τ-dependent parts of log p after plugging in the data and
# the numeric hyperparameters.
[log_p_for_μ_subs, log_p_for_τ_subs]
# Initialise q(τ) with the numeric Gamma prior. q(μ) = 1 is only a placeholder;
# the first sweep overwrites it before it is read.
q_τ = N(subs(p_τ, replacements...))
q_τ
q_μ = Sym(1)
# Coordinate-ascent sweeps (a fixed 7; convergence is not checked):
#   log q(μ) = ∫ q(τ) log p dτ  (μ-dependent part), normalised over μ ∈ (-∞, ∞)
#   log q(τ) = ∫ q(μ) log p dμ  (τ-dependent part), normalised over τ ∈ (0, ∞)
for _ in 1:7
    # Required when this file runs as a script. Under Julia's soft-scope rule the
    # assignments below would otherwise create new loop-locals, so the first
    # read of `q_τ` throws UndefVarError and the globals used after the loop
    # (`q_μ*q_τ`) never change. Harmless in a notebook/REPL.
    global q_μ, q_τ
    log_q_μ = N(integrate(log_p_for_μ_subs * q_τ, (τ, 0, oo)))
    z_q_μ = N(integrate(exp(log_q_μ), (μ, -oo, oo)))
    q_μ = 1/z_q_μ * exp(log_q_μ)
    # The (π=>float(π)) substitution is needed here, which is a bit annoying;
    # presumably it turns a symbolic π left by the Gaussian integral into a float.
    log_q_τ = N(integrate(log_p_for_τ_subs * q_μ, (μ, -oo, oo))(π=>float(π)))
    z_q_τ = N(integrate(exp(log_q_τ), (τ, 0, oo)))
    q_τ = 1/z_q_τ * exp(log_q_τ)
    display([q_μ, q_τ])
end
# The variational approximation q(μ)q(τ): symbolic form, Julia-printer form, and
# the string that is spliced into a Julia definition below.
q_μ*q_τ
julia_code(q_μ*q_τ)
"$(q_μ*q_τ)"
# Evaluation grid.
delta = 0.05
μs = -0.4:delta:2.0
τs = 0.0:delta:5.5
# Define a plain Julia function f(μ, τ) from the printed form of q(μ)q(τ).
# This relies on that printed form being valid Julia syntax.
eval(Meta.parse("f(μ,τ) = $(q_μ*q_τ)"))
# Row vector μs' broadcast against column τs: rows index τ, columns index μ.
@time c_f = f.(μs', τs);
figure(figsize=(5,4))
CS = contour(μs, τs, c_f)
clabel(CS, inline=1, fontsize=10)
xlabel("μ")
ylabel("τ")
grid(ls=":")
figure(figsize=(6.4, 4))
pcolormesh(μs, τs, c_f, cmap="CMRmap")
colorbar()
xlabel("μ")
ylabel("τ")
grid(ls=":")
# ---- Exact posterior (for comparison): integrate μ out -------------------
# p3 = exp(log p) is the joint density of (Y1..Y3, μ, τ).
p3 = simplify(exp(log_p))
Z3_μ = integrate(p3, (μ, -oo, oo))
Z3_μ = simplify(Z3_μ)
F3_μ = expand(-log(Z3_μ))
# Write ∫ p3 dμ = C * τ^(a₀+1/2) * exp(-coef*τ): τ^(a₀+1) comes from the joint
# and τ^(-1/2) from the Gaussian integral over μ. The code assumes C is then
# free of τ.
coef = coeff(collect(F3_μ, τ), τ)
C = simplify(exp(-simplify(F3_μ - τ*coef))/τ^(a₀+1/Sym(2)))
# Sign check on coef: take the stationary point in Y1, then Y2, then Y3, and
# substitute the last solution for every Yi.
sol1 = solve(diff(coef, Y[1]), Y[1])[1]
coef1 = simplify(subs(coef, (Y[1], sol1)))
sol2 = solve(diff(coef1, Y[2]), Y[2])[1]
coef2 = simplify(subs(coef1, (Y[2], sol2)))
sol = solve(diff(coef2, Y[3]), Y[3])[1]
simplify(subs(coef, ((y, sol) for y in Y)...)) # -> b₀ > 0 and hence coef > 0.
# ---- Exact posterior: normalisation over τ --------------------------------
# After integrating μ out, ∫ p3 dμ = C * τ^(a₀+1/2) * exp(-coef*τ), with coef > 0.
# The evidence is therefore
#   Z3 = C * ∫₀^∞ τ^(a₀+1/2) exp(-coef*τ) dτ = C * Γ(a₀+3/2) / coef^(a₀+3/2).
# The exponent must be the same τ^(a₀+1/2) that was divided out when C was
# formed. Using τ^(a₀+1), which is correct only for the q(τ) normaliser, would
# mis-normalise q3 by a constant factor.
ξ = symbols("ξ", positive=true)
Z3 = simplify(integrate(τ^(a₀+1/Sym(2))*exp(-ξ*τ), (τ, 0, oo)))
Z3 = Z3(ξ=>coef)
Z3 = simplify(C*simplify(Z3))
# Posterior. ξ was already replaced by coef in Z3, so no further substitution
# is needed.
q3 = p3/Z3
Nq3 = N(subs(q3, replacements...)) # numerical posterior
simplify(q_μ*q_τ)
# Define a plain Julia function g(μ, τ) from the printed numeric posterior.
eval(Meta.parse("g(μ,τ) = $Nq3")) # numerical posterior
# Row vector μs' broadcast against column τs: rows index τ, columns index μ.
@time c_g = g.(μs', τs);
# Side-by-side contour plots: variational approximation vs exact posterior.
figure(figsize=(10,4))
subplot(121)
CS = contour(μs, τs, c_f)
clabel(CS, inline=1, fontsize=10)
xlabel("μ")
ylabel("τ")
grid(ls=":")
title("variational approximation")
subplot(122)
CS = contour(μs, τs, c_g)
clabel(CS, inline=1, fontsize=10)
xlabel("μ")
ylabel("τ")
grid(ls=":")
title("exact posterior")
# Same comparison as colour maps.
figure(figsize=(10, 3))
subplot(121)
pcolormesh(μs, τs, c_f, cmap="CMRmap")
colorbar()
xlabel("μ")
ylabel("τ")
grid(ls=":")
title("variational approximation")
subplot(122)
pcolormesh(μs, τs, c_g, cmap="CMRmap")
colorbar()
xlabel("μ")
ylabel("τ")
grid(ls=":")
title("exact posterior")