Ground Truthからtoy dataをサンプリングする
using PDMats
using Distributions
using LinearAlgebra
using Plots
using ProgressBars
┌ Info: Precompiling Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80] └ @ Base loading.jl:1242 ┌ Warning: Module JSON with build ID 1290864591119004 is missing from the cache. │ This may mean JSON [682c06a0-de6a-54ab-a142-c8b1cf79cde6] does not support precompilation but is imported by a module that does. └ @ Base loading.jl:1000 ┌ Info: Skipping precompilation since __precompile__(false). Importing Plots [91a5bcdd-55d7-5caf-9e0b-520d859cae80]. └ @ Base loading.jl:1017
# Ground-truth parameters: five 2-D Gaussian components with mixing weights π_gt
μ1_gt = [0.; 5.]
μ2_gt = [-5.; 2.]
μ3_gt = [5.; -2.]
μ4_gt = [8.; -3.]
μ5_gt = [-4.; 6.]
Σ1_gt = [1. 0.; 0. 2.]
Σ2_gt = [.5 .75; .75 2.]
Σ3_gt = [.5 .1; .1 .5]
Σ4_gt = [.2 .2; .2 .5]
Σ5_gt = [2. 1.; 1. 2.5]
π_gt = [.3, .2, .15, .2, .15]
# Build one MvNormal per (mean, covariance) pair and wrap them in a mixture
mixture_components = MvNormal[
    MvNormal(μ, Σ) for (μ, Σ) in [
        (μ1_gt, Σ1_gt), (μ2_gt, Σ2_gt), (μ3_gt, Σ3_gt),
        (μ4_gt, Σ4_gt), (μ5_gt, Σ5_gt),
    ]
];
model_gt = MixtureModel(mixture_components, π_gt)
MixtureModel{MvNormal}(K = 5) components[1] (prior = 0.3000): FullNormal( dim: 2 μ: [0.0, 5.0] Σ: [1.0 0.0; 0.0 2.0] ) components[2] (prior = 0.2000): FullNormal( dim: 2 μ: [-5.0, 2.0] Σ: [0.5 0.75; 0.75 2.0] ) components[3] (prior = 0.1500): FullNormal( dim: 2 μ: [5.0, -2.0] Σ: [0.5 0.1; 0.1 0.5] ) components[4] (prior = 0.2000): FullNormal( dim: 2 μ: [8.0, -3.0] Σ: [0.2 0.2; 0.2 0.5] ) components[5] (prior = 0.1500): FullNormal( dim: 2 μ: [-4.0, 6.0] Σ: [2.0 1.0; 1.0 2.5] )
# Number of toy-data samples to draw from the ground-truth mixture.
# x_train is a 2×N matrix with one sample per column.
N = 1000
x_train = rand(model_gt, N)
scatter(x_train[1, :], x_train[2, :])
"""
A single Gaussian component, stored as a mean vector `μ` and a
precision (inverse covariance) matrix `Λ`. Code that needs a covariance
inverts `Λ` at the point of use.
"""
struct Gauss
    μ::Vector{Float64} # Mean Vector
    Λ::Matrix{Float64} # Precision Matrix
end
"""
State of a Dirichlet-process Gaussian mixture model: data dimension `D`,
current number of components `K`, mixing weights, and the `K` Gaussian
components. `K` (and the component list) grows and shrinks during
Gibbs sampling.
"""
struct DPGMM
    D::Int                          # data dimensionality
    K::Int                          # current number of components
    mixing_coeff::Vector{Float64}   # mixing weights (length K)
    components::Vector{Gauss}       # Gaussian components (length K)
end
DPGMMの可視化用の関数
"""
    plot_contours(gmm::DPGMM)

Overlay the pdf contours of every component of `gmm` on the current plot.
Each component's precision matrix is inverted to obtain the covariance
expected by `MvNormal`.
"""
function plot_contours(gmm::DPGMM)
    grid = -10:0.05:10
    for k in 1:gmm.K
        comp = gmm.components[k]
        dist = MvNormal(comp.μ, PDMat(Symmetric(inv(comp.Λ))))
        contour!(grid, grid, (x, y) -> pdf(dist, [x, y]), color=k, colorbar=:none)
    end
end
plot_contours (generic function with 1 method)
# Express the ground truth with our own types: Gauss stores precision
# matrices, so each ground-truth covariance is inverted here.
gaussians_gt = Gauss[
    Gauss(μ, inv(Σ)) for (μ, Σ) in [
        (μ1_gt, Σ1_gt), (μ2_gt, Σ2_gt), (μ3_gt, Σ3_gt),
        (μ4_gt, Σ4_gt), (μ5_gt, Σ5_gt),
    ]
];
gmm_gt = DPGMM(2, 5, π_gt, gaussians_gt)
DPGMM(2, 5, [0.3, 0.2, 0.15, 0.2, 0.15], Gauss[Gauss([0.0, 5.0], [1.0 0.0; 0.0 0.5]), Gauss([-5.0, 2.0], [4.571428571428571 -1.7142857142857142; -1.7142857142857144 1.1428571428571428]), Gauss([5.0, -2.0], [2.0833333333333335 -0.41666666666666674; -0.41666666666666674 2.0833333333333335]), Gauss([8.0, -3.0], [8.333333333333334 -3.3333333333333335; -3.3333333333333335 3.3333333333333335]), Gauss([-4.0, 6.0], [0.625 -0.25; -0.25 0.5])])
Ground Truthのプロット
# Ground-truth contours overlaid on the sampled data
plt = plot()
plot_contours(gmm_gt)
scatter!(x_train[1, :], x_train[2, :]; legend=:none)
# Initial number of mixture components (grows/shrinks during DP sampling)
K = 1;
# Dimensionality of the observations
D = 2;
# DP concentration parameter
α = 10.0
# Symmetric Dirichlet prior: α/K pseudo-count per component (DP limit form)
α0 = fill(α / K, K)
# Normal–Wishart hyperparameters: prior mean, mean-precision scaling,
# Wishart degrees of freedom and scale matrix
μ0 = zeros(D)
β0 = 1e-3
ν0 = Float64(D)
W0 = 5.0 * Matrix{Float64}(I, D, D)
2×2 Array{Float64,2}: 5.0 0.0 0.0 5.0
# Parameters of the Student's t prior-predictive for a brand-new cluster
# (Normal–Wishart prior with the Gaussian parameters marginalized out).
μ_s = μ0
# NOTE(review): Λ_s as computed here is the *precision* form of the
# predictive scale ((ν0+1-D)·β0/(1+β0)·W0). Distributions.jl's MvTDist
# expects a scale (covariance-like) matrix — confirm an inverse is taken
# wherever Λ_s is consumed.
Λ_s = (1 - D + ν0) * β0 / (1 + β0) * W0
ν_s = 1 - D + ν0
1.0
潜在変数の初期値は事前分布からサンプリングする
# Draw the initial mixing weights and per-cluster Gaussian parameters from
# their priors (Dirichlet and Normal–Wishart respectively).
mixing_coeff = rand(Dirichlet(α0))
μk = zeros(K, D)
Λk = zeros(K, D, D)
for k in 1:K
    Λ = rand(Wishart(ν0, W0))
    Λk[k, :, :] = Λ
    # μ | Λ ~ N(μ0, (β0 Λ)^-1)
    μk[k, :] = rand(MvNormal(μ0, PDMat(Symmetric(inv(β0 * Λ)))))
end
gaussians = Gauss[Gauss(μk[k, :], Λk[k, :, :]) for k in 1:K]
1-element Array{Gauss,1}: Gauss([16.64461721247163, -7.3290740500053895], [4.0163358060189305 -2.2064800529147965; -2.2064800529147965 5.503413115825852])
# Assemble the initial model and draw each point's initial cluster label
# from the prior mixing weights.
gmm = DPGMM(D, K, mixing_coeff, gaussians)
hidden_state = rand(Categorical(gmm.mixing_coeff), N);
データとクラスタの初期割当をプロット
# Scatter the data colored by the (random) initial cluster assignment
plt = plot()
scatter!(x_train[1, :], x_train[2, :], markercolor=hidden_state)
事後確率最大化に使う
"""
    compute_log_ewens(gmm::DPGMM, hidden_state)

Log-probability of the current partition under an Ewens-like formula,
used only to score states when tracking the MAP sample.
Reads the globals `α` and `N`.

NOTE(review): the textbook Ewens/CRP partition probability uses
Π (n_k - 1)! rather than Π n_k!, and has no 1/K! factor when labels are
taken in order of appearance — confirm the intended convention here.
"""
function compute_log_ewens(gmm::DPGMM, hidden_state::Array{Int64, 1})
    # α^K term
    log_lik = gmm.K * log(α)
    # Occupancy count n_k of each cluster
    counts = [count(x -> x == k, hidden_state) for k in 1:gmm.K]
    # Σ_k log(n_k!)  (empty range sums to 0 when n_k == 0)
    for n_k in counts
        log_lik += sum(log.(1:n_k))
    end
    # Rising factorial α^(N) = α(α+1)⋯(α+N-1)
    log_lik -= sum(log.(α .+ (0:N-1)))
    # 1/K! label-permutation factor
    log_lik -= sum(log.(1:gmm.K))
    return log_lik
end
"""
    compute_log_likelihood(gmm, x_train, hidden_state)

Joint log-probability of the data, component parameters and partition:
data likelihood + Normal–Wishart prior on each component + Ewens term.
Reads the globals `N`, `μ0`, `ν0` and `W0`.
"""
function compute_log_likelihood(gmm::DPGMM, x_train::Array{Float64}, hidden_state::Array{Int64, 1})
    # Covariance-parameterized view of every component (Gauss stores precisions)
    mvnormals = MvNormal[
        MvNormal(gmm.components[k].μ, PDMat(Symmetric(inv(gmm.components[k].Λ))))
        for k in 1:gmm.K
    ]
    # Data term: each point under its currently assigned component
    log_lik = 0.0
    for n in 1:N
        log_lik += logpdf(mvnormals[hidden_state[n]], x_train[:, n])
    end
    # Normal–Wishart prior on each component's (μ, Λ)
    for k in 1:gmm.K
        comp = gmm.components[k]
        Σ = PDMat(Symmetric(inv(comp.Λ)))
        log_lik += logpdf(MvNormal(μ0, Σ), comp.μ)
        log_lik += logpdf(Wishart(ν0, W0), comp.Λ)
    end
    # Partition prior (Ewens-like term)
    log_lik += compute_log_ewens(gmm, hidden_state)
    return log_lik
end
compute_log_likelihood (generic function with 1 method)
"""
    compute_likelihood(dists, x_train)

Evaluate every distribution in `dists` on every column of `x_train`.
Returns an N×K matrix whose (n, k) entry is `pdf(dists[k], x_train[:, n])`.
"""
function compute_likelihood(dists::Array{T}, x_train::Array{Float64}) where T <: Distribution
    K = size(dists, 1)      # number of distributions
    N = size(x_train, 2)    # number of data points
    p_xi = zeros(N, K)
    for (k, dist) in enumerate(dists)
        p_xi[:, k] = pdf(dist, x_train)
    end
    return p_xi
end
"""
    sample_hidden_state(gmm::DPGMM, x_train) -> (hidden_state, gmm)

One Gibbs sweep over the cluster assignments (CRP-style sampling).
For each point: remove its current assignment; delete its cluster if it
became empty (compacting the labels above it); then reassign the point
among existing clusters (∝ count × likelihood) or a brand-new cluster
(∝ α × prior predictive). A new cluster's parameters are drawn from the
Normal–Wishart prior. Mutates the global `hidden_state` in place and
returns it together with the (possibly resized) model.
Reads the globals `hidden_state`, `α`, `ν_s`, `μ_s`, `Λ_s`, `ν0`, `W0`,
`β0` and `μ0`.
"""
function sample_hidden_state(gmm::DPGMM, x_train::Array{Float64})
    # Number of data
    N = size(x_train, 2)
    for n in 1:N
        # Occupancy counts with point n removed
        m = [count(x -> x == k, hidden_state) for k in 1:gmm.K]
        m[hidden_state[n]] -= 1
        if m[hidden_state[n]] == 0
            # Point n was the sole member of its cluster: delete the
            # cluster and shift down the labels of all later clusters.
            removed = hidden_state[n]
            hidden_state[hidden_state .> removed] .-= 1
            m = m[(1:gmm.K) .!= removed]
            remained = [gmm.components[k] for k in 1:gmm.K if k != removed]
            gmm = DPGMM(gmm.D, gmm.K - 1, gmm.mixing_coeff, remained)
        end
        # CRP prior weights: existing clusters ∝ m_k, new cluster ∝ α.
        # (The denominator here is a constant; the vector is re-normalized
        # after multiplying by the likelihoods, so its exact value is moot.)
        prior_weight = [m; α]
        prior_weight = prior_weight ./ (sum(prior_weight) - 1 + α)
        # Likelihood of x_n under each existing component...
        dists = [MvNormal(
            gmm.components[k].μ,
            PDMat(Symmetric(inv(gmm.components[k].Λ)))
        ) for k in 1:gmm.K]
        # ...and under the Student's t prior predictive for a new cluster.
        # FIX: Λ_s is a precision matrix, but MvTDist expects a scale
        # (covariance-like) matrix, so it must be inverted here.
        dists = [dists; MvTDist(ν_s, μ_s, PDMat(Symmetric(inv(Λ_s))))]
        likelihood = [pdf(dists[k], x_train[:, n]) for k in 1:(gmm.K + 1)]
        # Posterior over assignments for point n
        posterior = prior_weight .* likelihood
        posterior = posterior ./ sum(posterior)
        hidden_state[n] = rand(Categorical(posterior))
        if hidden_state[n] == gmm.K + 1
            # Point opened a new cluster: draw its parameters from the prior
            Λ_new = rand(Wishart(ν0, W0))
            μ_new = rand(MvNormal(μ0, PDMat(Symmetric(inv(β0 * Λ_new)))))
            gmm = DPGMM(gmm.D, gmm.K + 1, gmm.mixing_coeff,
                        Gauss[gmm.components; Gauss(μ_new, Λ_new)])
        end
    end
    return hidden_state, gmm
end
sample_hidden_state (generic function with 1 method)
"""
    sample_mixing_coeff(gmm::DPGMM, hidden_state) -> DPGMM

Resample the mixing weights from their Dirichlet posterior given the
current assignment counts. The symmetric prior puts α/K pseudo-counts on
each of the K current components, matching the `Dirichlet(α0)` prior
used at initialization (`α0 = ones(K) * α / K`). Reads the global `α`.
"""
function sample_mixing_coeff(gmm::DPGMM, hidden_state::Array{Int64,1})
    # Occupancy count of each cluster
    m = [count(x -> x == k, hidden_state) for k in 1:gmm.K]
    # Dirichlet posterior: prior pseudo-counts α/K plus observed counts.
    # FIX: the original used `ones(K) / α` (i.e. 1/α per component), which
    # is inconsistent with the Dirichlet(α/K) prior declared for α0.
    α_post = ones(gmm.K) * α / gmm.K + m
    new_mixing_coeff = rand(Dirichlet(α_post))
    return DPGMM(gmm.D, gmm.K, new_mixing_coeff, gmm.components)
end
sample_mixing_coeff (generic function with 1 method)
"""
    sample_gauss(gmm, x_train, hidden_state) -> DPGMM

Resample every component's (μ, Λ) from its Normal–Wishart posterior
given the points currently assigned to it.
Reads the globals `β0`, `μ0`, `ν0` and `W0`.
"""
function sample_gauss(gmm::DPGMM, x_train::Array{Float64}, hidden_state::Array{Int64,1})
    new_μk = zeros(gmm.K, gmm.D)
    new_Λk = zeros(gmm.K, gmm.D, gmm.D)
    # Occupancy count of each cluster
    m = [count(x -> x == k, hidden_state) for k in 1:gmm.K]
    for k in 1:gmm.K
        # Points assigned to component k and their uncentered scatter matrix
        x_k = x_train[:, hidden_state .== k]
        S_k = x_k * x_k'
        # Standard Normal–Wishart posterior updates
        β_post = β0 + m[k]
        μ_post = ((sum(x_k, dims=2) + β0 * μ0) / β_post)[:, 1]
        ν_post = ν0 + m[k]
        W_post_inv = S_k + β0 * μ0 * μ0' - β_post * μ_post * μ_post' + inv(W0)
        W_post = PDMat(Symmetric(inv(W_post_inv)))
        # Draw Λ_k ~ Wishart, then μ_k | Λ_k ~ Normal with precision β_post·Λ_k
        new_Λk[k, :, :] = rand(Wishart(ν_post, W_post))
        new_μk[k, :] = rand(MvNormal(μ_post, PDMat(Symmetric(inv(β_post * new_Λk[k, :, :])))))
    end
    new_components = Gauss[Gauss(new_μk[k, :], new_Λk[k, :, :]) for k in 1:gmm.K]
    return DPGMM(gmm.D, gmm.K, gmm.mixing_coeff, new_components)
end
sample_gauss (generic function with 1 method)
アニメーション生成の参考
# Keep the parameters / assignment with the highest joint log-probability
# seen so far, together with that log-probability (MAP tracking).
gmm_max = deepcopy(gmm)
hidden_state_max = copy(hidden_state)
loglik_max = compute_log_likelihood(gmm, x_train, hidden_state)
-1.2018555290217667e6
loglik_list = Array{Float64, 1}()
0-element Array{Float64,1}
anim = Animation()
# Record the initial state (random assignment + prior-drawn component) as
# the first animation frame.
plt = scatter(x_train[1, :], x_train[2, :], markercolor=hidden_state, legend=:none)
plot_contours(gmm)
title!("Iteration: 0")
frame(anim, plt)
# Gibbs sampler main loop: resample assignments, mixing weights and
# component parameters; track the MAP state; snapshot every 50 iterations.
for iter in tqdm(1:1000)
    # Resample cluster assignments (may create or delete clusters)
    hidden_state, gmm = sample_hidden_state(gmm, x_train)
    # Resample mixing coefficients
    gmm = sample_mixing_coeff(gmm, hidden_state)
    # Resample Gaussian component parameters
    gmm = sample_gauss(gmm, x_train, hidden_state)
    loglik = compute_log_likelihood(gmm, x_train, hidden_state)
    if loglik > loglik_max
        gmm_max = deepcopy(gmm)
        hidden_state_max = copy(hidden_state)
        # FIX: reuse the value just computed instead of calling the
        # (deterministic) compute_log_likelihood a second time.
        loglik_max = loglik
    end
    push!(loglik_list, loglik)
    # Snapshot the current sampler state every 50 iterations
    if iter % 50 == 0
        plt = plot()
        scatter!(x_train[1, :], x_train[2, :], markercolor=hidden_state, legend=:none)
        plot_contours(gmm)
        title!("Iteration: $iter")
        frame(anim, plt)
    end
end
gif(anim, "gs_dpgmm.gif", fps=2);
100.00%┣█████████████████████████████████████████████████████████▉┫ 1000/1000 01:42<00:00, 9.80 it/s]┫ 2/1000 00:06<01:46:04, 0.16 it/s]5/1000 00:07<29:39, 0.56 it/s]7/1000 00:08<20:43, 0.80 it/s]8/1000 00:08<18:07, 0.91 it/s]13/1000 00:09<11:40, 1.41 it/s]18/1000 00:09<08:54, 1.84 it/s]19/1000 00:09<08:32, 1.92 it/s]23/1000 00:10<07:18, 2.23 it/s]25/1000 00:10<06:51, 2.37 it/s]26/1000 00:10<06:39, 2.44 it/s]30/1000 00:11<05:57, 2.72 it/s]31/1000 00:11<05:47, 2.79 it/s]34/1000 00:11<05:24, 2.99 it/s]35/1000 00:11<05:16, 3.05 it/s]37/1000 00:11<05:03, 3.18 it/s]38/1000 00:11<04:57, 3.24 it/s]40/1000 00:12<04:46, 3.35 it/s]42/1000 00:12<04:36, 3.47 it/s]┫ 44/1000 00:12<04:27, 3.58 it/s]┫ 52/1000 00:17<05:17, 2.99 it/s]55/1000 00:17<05:03, 3.12 it/s]57/1000 00:18<04:55, 3.20 it/s]┫ 58/1000 00:18<04:50, 3.25 it/s]61/1000 00:18<04:39, 3.36 it/s]66/1000 00:18<04:23, 3.56 it/s]70/1000 00:19<04:10, 3.72 it/s]72/1000 00:19<04:05, 3.79 it/s]76/1000 00:19<03:55, 3.93 it/s]78/1000 00:19<03:50, 4.01 it/s]82/1000 00:20<03:41, 4.15 it/s]84/1000 00:20<03:37, 4.23 it/s]86/1000 00:20<03:32, 4.31 it/s]94/1000 00:20<03:18, 4.58 it/s]96/1000 00:20<03:14, 4.65 it/s]100/1000 00:23<03:32, 4.24 it/s]102/1000 00:24<03:29, 4.29 it/s]106/1000 00:24<03:23, 4.40 it/s]112/1000 00:24<03:14, 4.57 it/s]126/1000 00:25<02:56, 4.97 it/s]129/1000 00:25<02:52, 5.06 it/s]131/1000 00:25<02:50, 5.12 it/s]134/1000 00:26<02:46, 5.21 it/s]136/1000 00:26<02:44, 5.26 it/s]138/1000 00:26<02:42, 5.32 it/s]140/1000 00:26<02:40, 5.37 it/s]146/1000 00:26<02:34, 5.55 it/s]┫ 153/1000 00:28<02:36, 5.43 it/s]159/1000 00:28<02:30, 5.60 it/s]163/1000 00:28<02:27, 5.70 it/s]167/1000 00:29<02:24, 5.79 it/s]173/1000 00:29<02:20, 5.93 it/s]175/1000 00:29<02:18, 5.97 it/s]18.80%┣███████████▎ ┫ 188/1000 00:30<02:09, 6.29 it/s]┫ 189/1000 00:30<02:09, 6.31 it/s]191/1000 00:30<02:07, 6.36 it/s]197/1000 00:30<02:04, 6.50 it/s]208/1000 00:33<02:06, 6.26 it/s]215/1000 00:33<02:03, 6.40 it/s]217/1000 00:34<02:01, 6.45 it/s]227/1000 
00:34<01:56, 6.65 it/s]232/1000 00:34<01:54, 6.75 it/s]237/1000 00:35<01:52, 6.84 it/s]┫ 241/1000 00:35<01:50, 6.91 it/s]245/1000 00:35<01:48, 6.98 it/s]247/1000 00:35<01:47, 7.02 it/s]251/1000 00:37<01:52, 6.67 it/s]258/1000 00:38<01:49, 6.79 it/s]260/1000 00:38<01:49, 6.82 it/s]262/1000 00:38<01:48, 6.85 it/s]265/1000 00:38<01:46, 6.90 it/s]269/1000 00:38<01:45, 6.99 it/s]273/1000 00:39<01:43, 7.05 it/s]275/1000 00:39<01:42, 7.08 it/s]277/1000 00:39<01:42, 7.11 it/s]281/1000 00:39<01:40, 7.17 it/s]283/1000 00:39<01:40, 7.20 it/s]285/1000 00:39<01:39, 7.23 it/s]291/1000 00:40<01:37, 7.33 it/s]293/1000 00:40<01:36, 7.36 it/s]296/1000 00:40<01:35, 7.41 it/s]298/1000 00:40<01:34, 7.45 it/s]300/1000 00:42<01:38, 7.17 it/s]306/1000 00:42<01:36, 7.24 it/s]311/1000 00:42<01:34, 7.32 it/s]313/1000 00:42<01:34, 7.35 it/s]318/1000 00:43<01:32, 7.42 it/s]324/1000 00:43<01:30, 7.50 it/s]326/1000 00:43<01:30, 7.52 it/s]328/1000 00:43<01:29, 7.55 it/s]332/1000 00:43<01:28, 7.61 it/s]334/1000 00:44<01:27, 7.64 it/s]336/1000 00:44<01:27, 7.67 it/s]338/1000 00:44<01:26, 7.69 it/s]342/1000 00:44<01:25, 7.74 it/s]348/1000 00:44<01:23, 7.82 it/s]349/1000 00:46<01:26, 7.57 it/s]355/1000 00:46<01:25, 7.62 it/s]357/1000 00:47<01:24, 7.65 it/s]363/1000 00:47<01:22, 7.73 it/s]366/1000 00:47<01:21, 7.78 it/s]369/1000 00:47<01:21, 7.82 it/s]372/1000 00:47<01:20, 7.86 it/s]374/1000 00:47<01:19, 7.89 it/s]┫ 379/1000 00:48<01:18, 7.95 it/s]385/1000 00:48<01:17, 8.02 it/s]388/1000 00:48<01:16, 8.06 it/s]391/1000 00:48<01:15, 8.10 it/s]397/1000 00:48<01:14, 8.18 it/s]398/1000 00:48<01:13, 8.19 it/s]406/1000 00:51<01:14, 7.99 it/s]411/1000 00:51<01:13, 8.05 it/s]413/1000 00:51<01:13, 8.07 it/s]427/1000 00:52<01:10, 8.22 it/s]432/1000 00:52<01:09, 8.27 it/s]438/1000 00:52<01:07, 8.33 it/s]443/1000 00:53<01:06, 8.39 it/s]448/1000 00:53<01:05, 8.45 it/s]450/1000 00:55<01:07, 8.24 it/s]452/1000 00:55<01:06, 8.25 it/s]457/1000 00:55<01:05, 8.31 it/s]462/1000 00:55<01:04, 8.36 it/s]465/1000 
00:55<01:04, 8.40 it/s]474/1000 00:56<01:02, 8.50 it/s]476/1000 00:56<01:01, 8.52 it/s]┫ 480/1000 00:56<01:01, 8.56 it/s]483/1000 00:56<01:00, 8.59 it/s]485/1000 00:56<01:00, 8.62 it/s]491/1000 00:56<00:59, 8.68 it/s]494/1000 00:57<00:58, 8.72 it/s]496/1000 00:57<00:58, 8.74 it/s]498/1000 00:57<00:57, 8.76 it/s]503/1000 00:59<00:58, 8.56 it/s]504/1000 00:59<00:58, 8.56 it/s]514/1000 01:00<00:56, 8.62 it/s]516/1000 01:00<00:56, 8.64 it/s]519/1000 01:00<00:55, 8.67 it/s]522/1000 01:00<00:55, 8.70 it/s]524/1000 01:00<00:55, 8.72 it/s]530/1000 01:00<00:53, 8.79 it/s]532/1000 01:00<00:53, 8.81 it/s]534/1000 01:00<00:53, 8.82 it/s]536/1000 01:01<00:53, 8.84 it/s]538/1000 01:01<00:52, 8.85 it/s]540/1000 01:01<00:52, 8.86 it/s]542/1000 01:01<00:52, 8.87 it/s]544/1000 01:01<00:51, 8.89 it/s]546/1000 01:01<00:51, 8.90 it/s]548/1000 01:01<00:51, 8.92 it/s]552/1000 01:04<00:52, 8.67 it/s]554/1000 01:04<00:51, 8.69 it/s]556/1000 01:04<00:51, 8.71 it/s]┫ 564/1000 01:04<00:50, 8.78 it/s]568/1000 01:04<00:49, 8.82 it/s]571/1000 01:04<00:48, 8.85 it/s]573/1000 01:04<00:48, 8.87 it/s]585/1000 01:05<00:46, 8.99 it/s]588/1000 01:05<00:46, 9.02 it/s]590/1000 01:05<00:45, 9.04 it/s]593/1000 01:05<00:45, 9.07 it/s]597/1000 01:05<00:44, 9.11 it/s]603/1000 01:07<00:44, 8.97 it/s]605/1000 01:07<00:44, 8.99 it/s]┫ 608/1000 01:07<00:43, 9.01 it/s]61.10%┣████████████████████████████████████▋ ┫ 611/1000 01:07<00:43, 9.04 it/s]616/1000 01:08<00:42, 9.07 it/s]618/1000 01:08<00:42, 9.08 it/s]628/1000 01:09<00:41, 9.14 it/s]630/1000 01:09<00:40, 9.16 it/s]632/1000 01:09<00:40, 9.17 it/s]636/1000 01:09<00:40, 9.20 it/s]638/1000 01:09<00:39, 9.21 it/s]647/1000 01:10<00:38, 9.28 it/s]650/1000 01:11<00:38, 9.12 it/s]652/1000 01:11<00:38, 9.13 it/s]654/1000 01:11<00:38, 9.14 it/s]662/1000 01:12<00:37, 9.18 it/s]668/1000 01:12<00:36, 9.22 it/s]670/1000 01:12<00:36, 9.23 it/s]672/1000 01:13<00:35, 9.24 it/s]678/1000 01:13<00:35, 9.29 it/s]680/1000 01:13<00:34, 9.30 it/s]683/1000 01:13<00:34, 9.32 
it/s]686/1000 01:13<00:34, 9.35 it/s]689/1000 01:13<00:33, 9.37 it/s]692/1000 01:14<00:33, 9.40 it/s]696/1000 01:14<00:32, 9.42 it/s]698/1000 01:14<00:32, 9.43 it/s]701/1000 01:16<00:32, 9.22 it/s]703/1000 01:16<00:32, 9.24 it/s]712/1000 01:16<00:31, 9.30 it/s]714/1000 01:17<00:31, 9.31 it/s]716/1000 01:17<00:30, 9.32 it/s]720/1000 01:17<00:30, 9.34 it/s]732/1000 01:18<00:28, 9.43 it/s]734/1000 01:18<00:28, 9.44 it/s]736/1000 01:18<00:28, 9.45 it/s]739/1000 01:18<00:28, 9.47 it/s]747/1000 01:18<00:27, 9.52 it/s]748/1000 01:18<00:26, 9.52 it/s]751/1000 01:20<00:27, 9.35 it/s]754/1000 01:20<00:26, 9.38 it/s]758/1000 01:21<00:26, 9.40 it/s]760/1000 01:21<00:26, 9.41 it/s]773/1000 01:21<00:24, 9.49 it/s]775/1000 01:21<00:24, 9.50 it/s]787/1000 01:22<00:22, 9.58 it/s]790/1000 01:22<00:22, 9.60 it/s]802/1000 01:25<00:21, 9.46 it/s]806/1000 01:25<00:20, 9.49 it/s]812/1000 01:25<00:20, 9.53 it/s]┫ 815/1000 01:25<00:19, 9.55 it/s]822/1000 01:26<00:19, 9.60 it/s]▌ ┫ 827/1000 01:26<00:18, 9.63 it/s]829/1000 01:26<00:18, 9.64 it/s]831/1000 01:26<00:18, 9.65 it/s]843/1000 01:27<00:16, 9.73 it/s]848/1000 01:27<00:16, 9.77 it/s]853/1000 01:29<00:15, 9.53 it/s]861/1000 01:30<00:15, 9.58 it/s]▊ ┫ 864/1000 01:30<00:14, 9.60 it/s]866/1000 01:30<00:14, 9.61 it/s]870/1000 01:30<00:13, 9.63 it/s]872/1000 01:30<00:13, 9.64 it/s]878/1000 01:31<00:13, 9.67 it/s]880/1000 01:31<00:12, 9.68 it/s]884/1000 01:31<00:12, 9.70 it/s]886/1000 01:31<00:12, 9.71 it/s]889/1000 01:31<00:11, 9.73 it/s]891/1000 01:31<00:11, 9.74 it/s]894/1000 01:31<00:11, 9.76 it/s]898/1000 01:32<00:10, 9.78 it/s]904/1000 01:34<00:10, 9.65 it/s]906/1000 01:34<00:10, 9.66 it/s]914/1000 01:34<00:09, 9.70 it/s]916/1000 01:34<00:09, 9.71 it/s]921/1000 01:34<00:08, 9.74 it/s]934/1000 01:35<00:07, 9.83 it/s]939/1000 01:35<00:06, 9.85 it/s]940/1000 01:35<00:06, 9.85 it/s]941/1000 01:35<00:06, 9.85 it/s]942/1000 01:36<00:06, 9.84 it/s]944/1000 01:36<00:06, 9.84 it/s]948/1000 01:36<00:05, 9.86 it/s]954/1000 01:38<00:05, 9.72 
it/s]956/1000 01:38<00:05, 9.73 it/s]962/1000 01:39<00:04, 9.75 it/s]968/1000 01:39<00:03, 9.79 it/s]970/1000 01:39<00:03, 9.80 it/s]974/1000 01:39<00:03, 9.82 it/s]981/1000 01:40<00:02, 9.85 it/s]989/1000 01:40<00:01, 9.89 it/s]┫ 996/1000 01:40<00:00, 9.93 it/s]
┌ Info: Saved animation to │ fn = C:\Users\chikuwa\workspace\julia-work\gs_dpgmm.gif └ @ Plots C:\Users\chikuwa\.julia\packages\Plots\Ih71u\src\animation.jl:95
plot(loglik_list, legend=:none, title="log likelihood", xlabel="iteration")
サンプリングの過程
結果
# Final result: MAP component contours over the MAP cluster assignment
plt = plot()
plot_contours(gmm_max)
scatter!(x_train[1, :], x_train[2, :]; markercolor=hidden_state_max, legend=:none)