In [4]:
# Notebook setup. Two worker processes are added (the later run log reports
# them as workers 2 and 3 under "Node Leaders").
# `addprocs` must come before `@everywhere`, so that the package is also
# loaded on the newly created workers, not just on the main process.
using Distributed
addprocs(2)
using Plots
@everywhere using DPMMSubClusters
In [7]:
# Run the parallel DPMM sampler configured by the parameter file "params_2d.jl",
# with verbose logging off. The cell output shows the run still reports
# "Saving Model:" with its timing.
dp = dp_parallel("params_2d.jl"; verbose = false)
Saving Model:
  2.849497 seconds (8.23 M allocations: 404.699 MiB, 3.77% gc time)
Out[7]:
dp_parallel_sampling(DPMMSubClusters.model_hyper_params(DPMMSubClusters.niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), 100000.0, 1000), DPMMSubClusters.local_group(DPMMSubClusters.model_hyper_params(DPMMSubClusters.niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), 100000.0, 1000), [-2.98295 -2.29031 … -1.92896 -2.1365; -0.943438 1.17951 … 2.59362 2.23584], [1, 3, 2, 5, 5, 2, 5, 5, 2, 5  …  2, 4, 3, 4, 5, 4, 2, 1, 3, 3], [2, 2, 1, 1, 2, 2, 2, 1, 1, 1  …  2, 1, 1, 2, 2, 1, 1, 2, 1, 1], DPMMSubClusters.local_cluster[local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-3.32295, -1.19545], [0.757923 -0.0832886; -0.0832886 0.572229], [1.34084 0.195161; 0.195161 1.77596], -0.851513), niw_sufficient_statistics(136.0, [-441.615, -154.467], [1504.24 484.289; 484.289 247.677]), niw_hyperparams(137.0, [-3.22346, -1.1275], 141.0, [0.60785 -0.0966756; -0.0966756 0.556849])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-2.87837, -1.70357], [0.679831 0.150706; 0.150706 0.344743], [1.6288 -0.712038; -0.712038 3.21198], -1.5528), niw_sufficient_statistics(66.0, [-192.705, -112.739], [584.701 330.323; 330.323 210.97]), niw_hyperparams(67.0, [-2.8762, -1.68267], 71.0, [0.499198 0.085394; 0.085394 0.369961])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-3.71637, -0.58372], [0.695401 0.0845759; 0.0845759 0.259263], [1.49743 -0.488487; -0.488487 4.01645], -1.75366), niw_sufficient_statistics(70.0, [-248.909, -41.7282], [919.535 153.966; 153.966 36.7071]), niw_hyperparams(71.0, [-3.50577, -0.587721], 75.0, [0.692217 0.102353; 0.102353 0.229101])), [0.499641, 0.500359], true, [-278.913, -272.184, -266.471, -272.155, -271.077]), 1000, 136, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([5.55928, 6.01653], 
[0.702932 0.166039; 0.166039 0.660936], [1.51236 -0.379931; -0.379931 1.60845], -0.827766), niw_sufficient_statistics(182.0, [1027.11, 1106.36], [5883.25 6250.12; 6250.12 6824.78]), niw_hyperparams(183.0, [5.61262, 6.04568], 187.0, [0.660289 0.216838; 0.216838 0.754508])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([5.20073, 6.40861], [0.523723 0.381104; 0.381104 0.743753], [3.04467 -1.56011; -1.56011 2.14395], -1.40944), niw_sufficient_statistics(93.0, [486.428, 597.313], [2572.3 3142.96; 3142.96 3874.01]), niw_hyperparams(94.0, [5.17477, 6.35439], 98.0, [0.613746 0.530631; 0.530631 0.851496])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([5.9153, 5.67959], [0.550579 0.421439; 0.421439 0.789988], [3.06983 -1.63768; -1.63768 2.1395], -1.35736), niw_sufficient_statistics(89.0, [540.681, 509.046], [3310.95 3107.16; 3107.16 2950.77]), niw_hyperparams(90.0, [6.00757, 5.65606], 94.0, [0.720943 0.521641; 0.521641 0.814636])), [0.502809, 0.497191], true, [-415.229, -417.327, -416.44, -425.272, -412.699]), 1000, 182, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-1.86981, 1.77523], [0.490766 -0.129436; -0.129436 0.578601], [2.16539 0.484409; 0.484409 1.83667], -1.31974), niw_sufficient_statistics(143.0, [-268.068, 251.318], [577.546 -479.714; -479.714 510.905]), niw_hyperparams(144.0, [-1.86158, 1.74526], 148.0, [0.564288 -0.0801716; -0.0801716 0.522233])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-1.41538, 2.10192], [0.471225 -0.188001; -0.188001 0.362744], [2.67531 1.38654; 1.38654 3.47538], -1.99812), niw_sufficient_statistics(64.0, [-94.7127, 142.126], [163.014 -220.614; -220.614 332.674]), niw_hyperparams(65.0, [-1.45712, 2.18655], 69.0, [0.434868 -0.195943; -0.195943 0.389995])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 
5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-2.30155, 1.44522], [0.503754 -0.210701; -0.210701 0.358906], [2.63117 1.54467; 1.54467 3.69307], -1.99213), niw_sufficient_statistics(79.0, [-173.355, 109.192], [414.532 -259.1; -259.1 178.231]), niw_hyperparams(80.0, [-2.16694, 1.3649], 84.0, [0.522397 -0.267708; -0.267708 0.407084])), [0.500631, 0.499369], true, [-264.935, -262.598, -263.375, -265.967, -264.688]), 1000, 143, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([4.34657, -0.36048], [0.53338 0.0445443; 0.0445443 0.92721], [1.88239 -0.0904321; -0.0904321 1.08285], -0.708116), niw_sufficient_statistics(381.0, [1676.5, -149.368], [7570.09 -641.005; -641.005 366.002]), niw_hyperparams(382.0, [4.38875, -0.391015], 386.0, [0.563018 0.0376523; 0.0376523 0.809836])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([4.24818, -1.05487], [0.502046 0.00576512; 0.00576512 0.415919], [1.99217 -0.0276137; -0.0276137 2.4047], -1.56649), niw_sufficient_statistics(190.0, [819.287, -205.721], [3620.66 -879.885; -879.885 289.247]), niw_hyperparams(191.0, [4.28946, -1.07707], 195.0, [0.571087 0.0130663; 0.0130663 0.372667])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([4.53533, 0.345487], [0.654089 -0.0429533; -0.0429533 0.31315], [1.54274 0.21161; 0.21161 3.22238], -1.59463), niw_sufficient_statistics(191.0, [857.218, 56.3534], [3949.43 238.88; 238.88 76.7544]), niw_hyperparams(192.0, [4.46468, 0.293508], 196.0, [0.649154 -0.0648968; -0.0648968 0.332726])), [0.498204, 0.501796], true, [-806.362, -810.834, -802.615, -805.98, -800.645]), 1000, 381, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-0.395949, -2.28661], [0.458608 -0.0165207; -0.0165207 0.540331], [2.18292 0.0667431; 0.0667431 1.85276], -1.39624), 
niw_sufficient_statistics(158.0, [-71.4696, -350.785], [103.872 165.506; 165.506 852.233]), niw_hyperparams(159.0, [-0.449494, -2.2062], 163.0, [0.470839 0.048036; 0.048036 0.51123])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-1.00192, -2.33012], [0.212153 -0.0549705; -0.0549705 0.618071], [4.82477 0.429109; 0.429109 1.6561], -2.05491), niw_sufficient_statistics(82.0, [-78.8377, -194.871], [87.237 183.521; 183.521 505.479]), niw_hyperparams(83.0, [-0.949852, -2.34784], 87.0, [0.199458 -0.0181339; -0.0181339 0.608675])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([0.123024, -2.05159], [0.24601 -0.0496297; -0.0496297 0.542399], [4.14132 0.378932; 0.378932 1.87833], -2.03277), niw_sufficient_statistics(76.0, [7.36814, -155.915], [16.635 -18.0147; -18.0147 346.753]), niw_hyperparams(77.0, [0.0956901, -2.02487], 81.0, [0.258394 -0.038212; -0.038212 0.445014])), [0.501036, 0.498964], true, [-292.034, -295.041, -288.85, -289.42, -295.046]), 1000, 158, 492, 492)], [0.00131345, 0.00185653, 0.00153941, 0.00356966, 0.00159244]))
In [8]:
# Rebuild and run the model from a saved checkpoint file. According to the
# printed log, this re-includes the params, reloads the data, recreates the
# model and runs it. The result replaces the `dp` from the previous cell.
# NOTE(review): "50" presumably refers to the iteration the checkpoint was
# written at — confirm against the params file.
dp = let checkpoint_path = "checkpoint__50.jld2"
    run_model_from_checkpoint(checkpoint_path)
end
Loading Model:
  1.073261 seconds (2.27 M allocations: 113.221 MiB, 2.60% gc time)
Including params
Loading data:
  0.000881 seconds (10.02 k allocations: 378.313 KiB)
Creating model:
Node Leaders:
Dict{Any,Any}(2=>Any[2, 3])
Running model:
Out[8]:
dp_parallel_sampling(DPMMSubClusters.model_hyper_params(DPMMSubClusters.niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), 100000.0, 1000), DPMMSubClusters.local_group(DPMMSubClusters.model_hyper_params(DPMMSubClusters.niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), 100000.0, 1000), [-2.98295 -2.29031 … -1.92896 -2.1365; -0.943438 1.17951 … 2.59362 2.23584], [1, 3, 2, 5, 5, 2, 5, 5, 2, 5  …  2, 4, 3, 4, 5, 4, 2, 1, 3, 3], [2, 2, 2, 1, 2, 2, 2, 1, 1, 1  …  2, 1, 1, 2, 2, 2, 2, 2, 1, 1], DPMMSubClusters.local_cluster[local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-3.15739, -1.1828], [0.865352 -0.260615; -0.260615 0.546059], [1.34958 0.644107; 0.644107 2.13871], -0.904823), niw_sufficient_statistics(136.0, [-441.615, -154.467], [1504.24 484.289; 484.289 247.677]), niw_hyperparams(137.0, [-3.22346, -1.1275], 141.0, [0.60785 -0.0966756; -0.0966756 0.556849])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-2.97523, -1.49102], [0.548311 0.0417648; 0.0417648 0.327489], [1.84167 -0.234869; -0.234869 3.08349], -1.72698), niw_sufficient_statistics(73.0, [-216.194, -120.227], [664.872 356.112; 356.112 219.425]), niw_hyperparams(74.0, [-2.92154, -1.62468], 78.0, [0.490392 0.0623651; 0.0623651 0.373011])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-3.56913, -0.533871], [0.632689 0.0954314; 0.0954314 0.192724], [1.70813 -0.845817; -0.845817 5.60758], -2.18189), niw_sufficient_statistics(63.0, [-225.42, -34.2408], [839.364 128.177; 128.177 28.2526]), niw_hyperparams(64.0, [-3.52219, -0.535013], 68.0, [0.741032 0.111388; 0.111388 0.219608])), [0.502684, 0.497316], true, [-267.679, -273.23, -278.76, -270.026, -268.595]), 1000, 136, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([5.53265, 
6.03432], [0.767141 0.240894; 0.240894 0.740073], [1.45195 -0.472611; -0.472611 1.50505], -0.673914), niw_sufficient_statistics(182.0, [1027.11, 1106.36], [5883.25 6250.12; 6250.12 6824.78]), niw_hyperparams(183.0, [5.61262, 6.04568], 187.0, [0.660289 0.216838; 0.216838 0.754508])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([5.16593, 6.3178], [0.638812 0.661939; 0.661939 1.15022], [3.87786 -2.23166; -2.23166 2.15369], -1.21533), niw_sufficient_statistics(86.0, [446.471, 553.546], [2344.31 2893.39; 2893.39 3600.31]), niw_hyperparams(87.0, [5.13185, 6.3626], 91.0, [0.638397 0.578915; 0.578915 0.915633])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([5.99587, 5.65111], [0.699121 0.479286; 0.479286 0.798382], [2.43075 -1.45923; -1.45923 2.12854], -1.11337), niw_sufficient_statistics(96.0, [580.638, 552.813], [3538.93 3356.72; 3356.72 3224.47]), niw_hyperparams(97.0, [5.98596, 5.6991], 101.0, [0.675759 0.471351; 0.471351 0.781497])), [0.499795, 0.500205], true, [-414.586, -421.263, -419.293, -416.832, -411.363]), 1000, 182, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-1.88092, 1.79758], [0.511821 -0.0225497; -0.0225497 0.502473], [1.95768 0.0878555; 0.0878555 1.9941], -1.35997), niw_sufficient_statistics(142.0, [-267.335, 251.526], [577.009 -479.867; -479.867 510.862]), niw_hyperparams(143.0, [-1.86948, 1.75892], 147.0, [0.559397 -0.0656098; -0.0656098 0.499643])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-1.50945, 2.39099], [0.447635 -0.25529; -0.25529 0.421086], [3.41459 2.07015; 2.07015 3.62988], -2.09298), niw_sufficient_statistics(59.0, [-85.7616, 133.647], [146.055 -204.721; -204.721 317.659]), niw_hyperparams(60.0, [-1.42936, 2.22745], 64.0, [0.444853 -0.213921; -0.213921 0.390096])), cluster_parameters(niw_hyperparams(1.0, 
[0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-2.1953, 1.41181], [0.411573 -0.230743; -0.230743 0.363518], [3.77203 2.3943; 2.3943 4.27068], -2.33954), niw_sufficient_statistics(83.0, [-181.574, 117.879], [430.954 -275.146; -275.146 193.203]), niw_hyperparams(84.0, [-2.16159, 1.40332], 88.0, [0.493928 -0.231138; -0.231138 0.372522])), [0.502065, 0.497935], true, [-270.602, -268.809, -265.769, -265.313, -268.296]), 1000, 142, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([4.3953, -0.467719], [0.629589 -0.00111722; -0.00111722 0.730138], [1.58834 0.0024304; 0.0024304 1.36961], -0.777213), niw_sufficient_statistics(380.0, [1674.16, -145.661], [7564.61 -632.324; -632.324 352.26]), niw_hyperparams(381.0, [4.39413, -0.382312], 385.0, [0.553568 0.0200729; 0.0200729 0.783304])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([4.36841, -1.17706], [0.656383 0.0235926; 0.0235926 0.356892], [1.52713 -0.100952; -0.100952 2.80864], -1.45371), niw_sufficient_statistics(176.0, [760.934, -195.964], [3373.1 -845.509; -845.509 271.969]), niw_hyperparams(177.0, [4.29906, -1.10714], 181.0, [0.59 -0.0168378; -0.0168378 0.331544])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([4.51891, 0.224571], [0.596676 -0.0722436; -0.0722436 0.373002], [1.7162 0.332395; 0.332395 2.74533], -1.52628), niw_sufficient_statistics(204.0, [913.228, 50.3031], [4191.51 213.185; 213.185 80.2907]), niw_hyperparams(205.0, [4.45477, 0.245381], 209.0, [0.61382 -0.0521718; -0.0521718 0.34903])), [0.501926, 0.498074], true, [-792.46, -795.821, -782.109, -799.746, -795.891]), 1000, 380, 492, 492), local_cluster(splittable_cluster_params(cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-0.61038, -2.27563], [0.693892 0.0860641; 0.0860641 0.56745], [1.46878 -0.222767; -0.222767 1.79606], 
-0.951033), niw_sufficient_statistics(160.0, [-69.8605, -354.701], [109.893 156.978; 156.978 866.018]), niw_hyperparams(161.0, [-0.433916, -2.20311], 165.0, [0.5126 0.0185889; 0.0185889 0.542868])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([-0.923207, -2.02815], [0.196433 0.0213241; 0.0213241 0.532829], [5.113 -0.204625; -0.204625 1.88496], -2.26134), niw_sufficient_statistics(80.0, [-76.8799, -172.652], [85.4577 170.805; 170.805 412.726]), niw_hyperparams(81.0, [-0.949135, -2.1315], 85.0, [0.205745 0.08159; 0.08159 0.584914])), cluster_parameters(niw_hyperparams(1.0, [0.0, 0.0], 5.0, [1.0 0.0; 0.0 1.0]), mv_gaussian([0.0368826, -2.2552], [0.381483 0.164572; 0.164572 0.759517], [2.89165 -0.626562; -0.626562 1.45239], -1.3369), niw_sufficient_statistics(80.0, [7.01942, -182.049], [24.4348 -13.8275; -13.8275 453.292]), niw_hyperparams(81.0, [0.0866595, -2.24752], 85.0, [0.339135 0.0229267; 0.0229267 0.578049])), [0.500527, 0.499473], true, [-294.526, -291.831, -310.561, -295.105, -291.637]), 1000, 160, 492, 492)], [0.00148355, 0.00176785, 0.00123266, 0.0038411, 0.00146099]))
In [9]:
# Scatter-plot the model's 2-D points, coloured by their cluster label.
# NOTE(review): Out[7]/Out[8] display `dp` as a single `dp_parallel_sampling`
# value, yet it is indexed as `dp[1]` here. Confirm which return shape the
# installed DPMMSubClusters version produces.
labels = Array(dp[1].group.labels)   # one cluster label per point, copied into a plain Array
pts = Array(dp[1].group.points)      # one point per column: row 1 = x, row 2 = y
plt = Plots.plot()
# Draw into `plt` explicitly rather than into Plots' global "current plot".
# Also hide the legend: a single unnamed series would otherwise be listed as "y1".
Plots.plot!(plt, pts[1, :], pts[2, :];
    seriestype = :scatter,
    color = labels,
    markersize = 3,
    markerstrokewidth = 0.5,
    legend = false)
Out[9]:
-6 -4 -2 0 2 4 6 -2.5 0.0 2.5 5.0 7.5 y1
In [ ]: