クラスタ数は『続・わかりやすいパターン認識』の13章P.279のデータと揃えてある. ただし,関係行列の行数・列数は比率はそのままで数を10倍に増やしてる.
using Distributions
using Plots
using DelimitedFiles
行(顧客)と列(商品)のクラスタ数c1, c2
c1 = 4;
c2 = 3;
行(顧客)と列(商品)のデータ数K, L
K = 15 * 10;
L = 10 * 10;
行(顧客)と列(商品)の混合比率π1, π2
π1 = [4/15; 3/15; 5/15; 3/15];
π2 = [4/10; 3/10; 3/10];
行列Θ={θij}の各成分は, 顧客クラスタiに属する顧客が商品クラスタjに属する商品を購入する確率を表す.
Θ = [0.2 0.9 0.1;
1.0 0.8 0.0;
0.1 0.1 0.9;
0.2 0.7 0.1];
顧客と商品の所属クラスを表す潜在変数s1, s2
s1 = rand(Categorical(π1), K);
s2 = rand(Categorical(π2), L);
顧客kが商品lを購入してたら1,そうでなければ0を成分にもつ 関係行列R={Rkl}
R = zeros(K, L);
for k in 1:K
for l in 1:L
R[k, l] = rand(Bernoulli(Θ[s1[k], s2[l]]))
end
end
顧客と商品の関係行列Rをプロット
heatmap(R, yflip=true, c=ColorGradient([:white, :black]))
顧客と商品の関係行列Rを真のラベルを用いて並べ替えて表示する
row_idxs = sort(collect(1:K), by=i->s1[i]);
col_idxs = sort(collect(1:L), by=i->s2[i]);
heatmap(R[row_idxs, col_idxs], yflip=true, c=ColorGradient([:white, :black]))
writedlm("IRM_toydata_R_200402.csv", R, ',')
writedlm("IRM_toydata_s1_200402.csv", s1, ',')
writedlm("IRM_toydata_s2_200402.csv", s2, ',')
以下,出力したデータを読み込むテスト
R_gt = readdlm("IRM_toydata_R_200402.csv", ',');
s1_gt = readdlm("IRM_toydata_s1_200402.csv", ',');
s2_gt = readdlm("IRM_toydata_s2_200402.csv", ',');
row_idxs = sort(collect(1:K), by=i->s1_gt[i]);
col_idxs = sort(collect(1:L), by=i->s2_gt[i]);
heatmap(R_gt[row_idxs, col_idxs], yflip=true, c=ColorGradient([:white, :black]))