In [1]:
import sklearn.mixture
import numpy as np
import scipy as sp
from keras import Input
from keras.layers   import Dense, Lambda, Layer, Activation
from keras.models   import Model
from keras.layers.normalization import BatchNormalization
import keras.backend as K
import matplotlib.pyplot as plt
import autograd.numpy.random as npr
import pylab
from keras                   import backend
from sklearn.metrics         import adjusted_rand_score
%matplotlib inline
from sklearn.metrics import accuracy_score
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
  from ._conv import register_converters as _register_converters
Using TensorFlow backend.
In [2]:
def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate):
    rads = np.linspace(0, 2*np.pi, num_classes, endpoint=False)

    features = npr.randn(num_classes*num_per_class, 2) \
        * np.array([radial_std, tangential_std])
    features[:,0] += 1.
    labels    = np.repeat(np.arange(num_classes), num_per_class)

    angles    = rads[labels] + rate * np.exp(features[:,0])
    rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
    rotations = np.reshape(rotations.T, (-1, 2, 2))

    return 10*np.einsum('ti,tij->tj', features, rotations),labels
In [3]:
num_clusters        = 5      # number of clusters in pinwheel data
samples_per_cluster = 200    # number of samples per cluster in pinwheel
K = 15                       # number of components in mixture model
N = 2                        # number of latent dimensions
P = 2                        # number of observation dimensions
batch_size          = num_clusters * samples_per_cluster
data, label = make_pinwheel_data(0.3, 0.05, num_clusters, samples_per_cluster, 0.25)

plt.figure(figsize=(10,10))
pylab.scatter(data[:,0],data[:,1],c=label)
Out[3]:
<matplotlib.collections.PathCollection at 0x11b2cbfd0>
In [4]:
def vae_loss(y_true, y_pred):
    # 入力と出力の交差エントロピー
    xent_loss =   backend.sum((backend.square(y_true - y_mean))/(backend.exp(y_log_var)) + backend.log(np.pi * 2) + y_log_var, axis=-1) * 0.5
    # 事前分布と事後分布のKL情報量
    kl_loss   =   backend.sum((backend.square(z - z_mean))/(backend.exp(z_log_var)) + backend.log(np.pi * 2) + z_log_var, axis=-1) * 0.5
    return backend.mean(xent_loss  - kl_loss - gmm_loss)


def sampling(args):
    mean, log_var = args
    epsilon = backend.random_normal(shape=backend.shape(mean), mean=0.,stddev=1.)
    return mean + backend.exp(log_var / 2) * epsilon
In [5]:
original_dim      = data.shape[1]
intermediate_dim  = 100
latent_dim        = 2

x                 = Input(shape=(original_dim,))
gmm_loss          = Input(shape=(1,))

decoder           = Dense(intermediate_dim)(x)
decoder           = BatchNormalization()(decoder)
decoder           = Activation('tanh')(decoder)



z_mean            = Dense(latent_dim,activation='linear')(decoder)
z_log_var         = Dense(latent_dim,activation='linear')(decoder)

z                 = Lambda(sampling,   output_shape=(latent_dim,))([z_mean, z_log_var])

decoder           = Dense(intermediate_dim)(z)
decoder           = BatchNormalization()(decoder)
decoder           = Activation('tanh')(decoder)


y_mean            = Dense(original_dim,activation='linear')(decoder)
y_log_var         = Dense(original_dim,activation='linear')(decoder)

y                 = Lambda(sampling, output_shape=(original_dim,))([y_mean, y_log_var])

VAE               = Model(input=[x,gmm_loss], output=y)
VAE_latent        = Model(input=x, output=z)

VAE_latent_mean   = Model(input=x, output=z_mean)
VAE_latent_var    = Model(input=x, output=z_log_var)
VAE_predict_mean  = Model(input=x, output=y_mean)
VAE_predict_var   = Model(input=x, output=y_log_var)

VAE.compile(optimizer='adam', loss=vae_loss)
VAE.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 2)            0                                            
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 100)          300         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 100)          400         dense_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 100)          0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2)            202         activation_1[0][0]               
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 2)            202         activation_1[0][0]               
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 2)            0           dense_2[0][0]                    
                                                                 dense_3[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 100)          300         lambda_1[0][0]                   
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 100)          400         dense_4[0][0]                    
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 100)          0           batch_normalization_2[0][0]      
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 2)            202         activation_2[0][0]               
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 2)            202         activation_2[0][0]               
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 2)            0           dense_5[0][0]                    
                                                                 dense_6[0][0]                    
==================================================================================================
Total params: 2,208
Trainable params: 1,808
Non-trainable params: 400
__________________________________________________________________________________________________
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/ipykernel_launcher.py:29: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=[<tf.Tenso..., outputs=Tensor("la...)`
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/ipykernel_launcher.py:30: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor("in..., outputs=Tensor("la...)`
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/ipykernel_launcher.py:32: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor("in..., outputs=Tensor("de...)`
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/ipykernel_launcher.py:33: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor("in..., outputs=Tensor("de...)`
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/ipykernel_launcher.py:34: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor("in..., outputs=Tensor("de...)`
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/ipykernel_launcher.py:35: UserWarning: Update your `Model` call to the Keras 2 API: `Model(inputs=Tensor("in..., outputs=Tensor("de...)`
In [10]:
from scipy.stats import multivariate_normal
epoch  = 500
latent = VAE_latent.predict(data)
gmm    = sklearn.mixture.GaussianMixture(n_components=num_clusters, covariance_type='full', max_iter=1)
for i in range(1):
    gmm.fit(latent)
    
pi_i   = gmm.predict_proba(latent)
mu_    = gmm.means_
sigma_ = gmm.covariances_

liklyhood = np.zeros(data.shape[0])

for m,s,pi in zip(mu_,sigma_,pi_i.T):
    liklyhood += pi*multivariate_normal.logpdf(latent,m,s)
    
for i in range(epoch):
    loss   = VAE.train_on_batch([data,liklyhood],data)
    latent = VAE_latent.predict(data)
    gmm.fit(latent)
    liklyhood = np.zeros(data.shape[0])
    for m,s,pi in zip(mu_,sigma_,pi_i.T):
        liklyhood += pi*multivariate_normal.logpdf(latent,m,s)
        
    print(i,loss,adjusted_rand_score(gmm.predict(latent),label),liklyhood.mean())

    
/Users/tomohiromimura/.pyenv/versions/3.6.4/envs/normal/lib/python3.6/site-packages/sklearn/mixture/base.py:237: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.
  % (init + 1), ConvergenceWarning)
0 103.41993 0.14651976688765464 -4.642745632400823
1 95.59113 0.12824113232861528 -4.768401059745967
2 87.78585 0.1427903809605796 -5.087102852087367
3 79.762024 0.12900684280492689 -5.095885069702208
4 73.61033 0.14716991183681807 -5.224178837160909
5 69.6406 0.1516915803066779 -5.321435847776818
6 63.365334 0.17547067308204586 -5.448860569112303
7 61.06198 0.15748745211492618 -5.547109175062757
8 57.81444 0.16493515084376786 -5.875019881697466
9 55.200344 0.17322788825720362 -5.922366681168228
10 53.785786 0.1994689807213474 -6.31613884267172
11 52.853863 0.20554879153442485 -6.427000457735698
12 50.0747 0.19929126174245684 -6.586035789843281
13 49.4235 0.2444029770832334 -7.113938276925496
14 48.939808 0.20281728465848078 -7.0752181247013874
15 48.175327 0.2479702379579886 -7.197497229914569
16 47.22448 0.24044023733156006 -7.74776051846974
17 47.204098 0.23132759000805986 -7.727132936996007
18 46.164593 0.25417393033220814 -8.361216120912708
19 46.21272 0.2623205017144226 -8.257024025246741
20 45.360035 0.2679299297823144 -8.432337060020489
21 45.0898 0.26341299480972885 -8.955047223624115
22 45.183117 0.2847788815989396 -8.986750684028825
23 44.341305 0.304357909891215 -9.118678806610129
24 44.023094 0.3049789727679779 -9.133409513398702
25 43.354504 0.298181801252176 -9.771252349173988
26 43.328617 0.3313741968663225 -9.46552848005264
27 42.637848 0.2933899887390654 -10.206163685103087
28 42.854668 0.2824323410859363 -10.098439694641483
29 42.127445 0.3350364692008074 -10.074408098165582
30 41.744812 0.3441112057140879 -10.648393606124756
31 41.759315 0.31496667977320336 -10.712282032424996
32 41.30469 0.29374557885406877 -10.91398448299344
33 41.093445 0.35681249754046634 -10.757477012351181
34 40.302723 0.3352114411684572 -10.701082358232117
35 39.955578 0.3217871236958889 -11.22482994338767
36 39.99076 0.3064558482089871 -11.237795999159438
37 39.449543 0.3254595325436984 -11.367971451172824
38 39.341152 0.33066540446604475 -11.776637032519774
39 39.379684 0.3520454813402633 -11.470247045814979
40 38.41451 0.35313500056111435 -11.808370536369395
41 38.31462 0.312527838276921 -11.839242458480676
42 37.946274 0.3995640222961814 -12.105170865732658
43 37.631298 0.357406945164926 -11.976346107920254
44 37.019817 0.37389684570629234 -12.708690839465294
45 37.623817 0.3768489766593488 -12.691160643150921
46 36.988323 0.3721888606848408 -13.258830140540232
47 37.19482 0.38520590966352464 -12.850440961070111
48 36.291996 0.4304298830365427 -12.95546218889975
49 36.04172 0.40958138331844146 -12.902951721902774
50 35.56395 0.40012647428659726 -13.352228439009615
51 35.744576 0.43068868984081443 -13.31682239217566
52 35.11469 0.43498224653250067 -13.392760720652564
53 34.971016 0.46538327211209746 -13.543084364713444
54 34.81939 0.45541808830069896 -14.209355551856541
55 35.005676 0.45952417621804315 -14.474968799058402
56 34.911633 0.4587407957264953 -14.402441757502897
57 34.540745 0.4817640966533137 -14.555936759869509
58 34.306637 0.5051612867073974 -14.775263593105267
59 34.190426 0.5279150869090891 -14.726068002626125
60 33.74902 0.49095220499003067 -15.274093792006092
61 33.892635 0.4503849097359772 -15.233248190956688
62 33.461246 0.592579195162078 -15.18504666062564
63 33.273697 0.5311190502341395 -15.230914296184357
64 33.029484 0.616360764946402 -15.587959826136379
65 32.88036 0.6172733434788995 -15.707476731163977
66 32.601303 0.5739292095600497 -16.11516390670334
67 32.833195 0.7025179919737342 -16.136710411741138
68 32.41162 0.6681456464802271 -16.359390852426053
69 32.217342 0.6572200271280586 -16.506623499510297
70 32.15791 0.616570248076777 -16.69191646860091
71 31.975357 0.6508899274075721 -16.767368705439125
72 31.919617 0.6881436464011886 -17.198648628042072
73 32.01978 0.6750837457359226 -17.15287093219415
74 31.779526 0.6918735806427492 -17.535223076142508
75 31.734312 0.7292373063569662 -17.546839876061203
76 31.394787 0.7247484048027135 -17.85347662678195
77 31.499252 0.6853985535919548 -18.16091017687199
78 31.508116 0.7349880420462164 -18.114569214957857
79 31.26149 0.7423769075505395 -18.779482900153415
80 31.529526 0.7470846068566828 -18.810700891083854
81 31.366941 0.7224087163678539 -19.358948249874157
82 31.610304 0.7912928349563398 -19.049027507685462
83 31.25202 0.7860326418541864 -18.9841620307431
84 30.79375 0.7794879457044368 -18.918911108164814
85 30.554947 0.7729100812809085 -19.20043552859653
86 30.642025 0.751684938920648 -19.32328037807464
87 30.525583 0.7823977889876562 -19.54307237088586
88 30.475044 0.7733421621124993 -19.655804229592643
89 30.335756 0.6100574237272759 -19.19181382693341
90 29.599869 0.7911275556904153 -20.235269492411447
91 30.501158 0.7853013833657153 -20.48729988680503
92 30.608463 0.8255076217483036 -20.841053862008533
93 30.705582 0.8074868199114797 -20.680699265385524
94 30.211409 0.8114180020669893 -20.277373519202893
95 29.582516 0.8259621798364055 -21.12387102984487
96 30.403755 0.8208680269386627 -20.869931041043717
97 29.891912 0.8096274224133904 -21.403154765954874
98 30.297428 0.8257224119363835 -20.950070943038387
99 29.704586 0.8298270017033267 -21.75907162496246
100 30.329098 0.8226775969883754 -21.33020752691789
101 29.64697 0.8180399408140512 -21.685804283846753
102 29.933577 0.8410587941716282 -21.7832938220972
103 29.77149 0.8174365149883811 -21.914447368881238
104 29.740046 0.8316926111696454 -22.0562411135716
105 29.806938 0.8317390103536481 -22.616170894713623
106 30.21606 0.8257662151507682 -22.578665375031882
107 29.998287 0.831902781528124 -22.63572117770363
108 29.980158 0.8567000423284565 -22.578966722635645
109 29.694403 0.8559573638817676 -22.96378858344357
110 30.065737 0.8146205189155667 -23.538564052303837
111 30.44531 0.8504772518209407 -23.334574091034835
112 30.1668 0.8499100301191843 -23.355271732430868
113 30.00439 0.8816800992327803 -23.82643283273959
114 30.351254 0.8704435257268806 -24.379300469557332
115 30.837906 0.8654744408674866 -23.992831492809703
116 30.281492 0.8771201441331563 -24.30258774625447
117 30.433645 0.8430568413968007 -24.193753794898146
118 30.190172 0.8819701597824047 -24.206014705524403
119 30.08351 0.8984660804866692 -24.318340916619256
120 30.183401 0.6292624939866924 -24.424072052119243
121 30.207083 0.8661644136773614 -24.425884975480553
122 30.132326 0.8439893360166838 -25.03229175762851
123 30.647522 0.8543628340837299 -24.696700128585203
124 30.135061 0.8886634471791776 -24.740907437579676
125 30.14606 0.8544594782887868 -25.12309035206863
126 30.399416 0.8769993662081029 -25.366146961600275
127 30.585445 0.8728957045242338 -25.835406509532216
128 30.914303 0.8774428347002621 -25.83783393471013
129 30.81512 0.8816079733718615 -25.850730226330057
130 30.847914 0.872977576349313 -26.079237589488496
131 30.905712 0.8861396547623495 -25.77705368727466
132 30.563902 0.8704937049471267 -26.14619757126324
133 30.812027 0.8750488408319063 -26.056442470836032
134 30.669239 0.8698272815488811 -26.029308435336134
135 30.626171 0.8914968532285611 -27.266426869315183
136 31.820328 0.8844745263033172 -26.92858053027846
137 31.370865 0.8754189772700823 -26.314381623393917
138 30.68016 0.886656686205439 -26.61959495337881
139 30.959408 0.9003565107749821 -27.275590168946533
140 31.447756 0.8912978112471132 -26.974919988135994
141 31.060806 0.8770648436681338 -27.315778715837606
142 31.502014 0.9145431755694918 -27.756548926950497
143 31.786802 0.8932503267172457 -27.151328056172446
144 31.135513 0.9056926942914326 -27.302334512345414
145 31.227976 0.8798300593183281 -27.385711123808193
146 31.358395 0.8796832637364902 -28.417603782630735
147 32.316177 0.6777675686638298 -27.51396937259316
148 31.290043 0.9146943413054633 -28.153318462063016
149 31.84735 0.9075799798911133 -28.510398893683746
150 32.240417 0.8682087851616029 -28.578065809607512
151 32.180126 0.900838307524805 -28.445030593428957
152 31.998781 0.9075921721250327 -28.513163864027362
153 32.003605 0.9100116603252515 -28.39934313925996
154 31.857119 0.9119978233936767 -29.066594372314093
155 32.561096 0.9024910872752789 -29.038974437597084
156 32.490494 0.9101649602444734 -28.886889872738173
157 32.295227 0.9099853847217502 -29.566824717529045
158 32.935562 0.9147736555772045 -29.68063156072008
159 33.01923 0.914564046020627 -29.41593166009342
160 32.693607 0.9098399540692363 -29.285230477929353
161 32.48901 0.904811649015421 -29.7503420543099
162 32.947327 0.8748841171170313 -29.68219213928569
163 32.866703 0.9171984435914976 -29.928895009431674
164 33.01393 0.9122909455632955 -30.30662992168904
165 33.41035 0.9074694802850708 -30.485532662255974
166 33.553406 0.8960396098368664 -30.278241684491498
167 33.318344 0.9123880461170936 -30.486436896562335
168 33.430973 0.9339685979529423 -30.191393828643783
169 33.125484 0.8886233746152078 -30.673051760330118
170 33.618843 0.9314895944516436 -31.196794469934837
171 34.06396 0.9193193040370262 -30.85431079793869
172 33.745872 0.8933664823506843 -31.391170458718502
173 34.29376 0.6422618884683352 -31.466511595441556
174 34.33469 0.9124012385137646 -31.927293246976138
175 34.764214 0.9336874597484772 -31.98390602481236
176 34.774918 0.9147989173794299 -31.386858646400647
177 34.182602 0.9098195260567447 -31.761808484492335
178 34.490993 0.9072116490909252 -32.103412823672215
179 34.854267 0.9099229519732478 -32.16684725785204
180 34.85346 0.9056609651567761 -32.47507091042791
181 35.1162 0.6506734638962807 -33.23024040486848
182 35.88411 0.869915553236498 -32.24347790050496
183 34.91281 0.926610563522746 -32.31530692055153
184 34.945496 0.9241748169046513 -33.02028321699302
185 35.56103 0.8960133758452035 -32.35111637638646
186 34.89423 0.928474779057068 -33.23602466379901
187 35.805717 0.9171710919094824 -32.59793144492862
188 35.22651 0.91992098787889 -33.074206762892125
189 35.561314 0.9360366816771034 -33.31097100301381
190 35.782406 0.907720436560557 -33.533532252302116
191 36.057888 0.9103384532191316 -33.09082720661792
192 35.530277 0.8984632561676342 -33.80330196825381
193 36.278805 0.8981731626922277 -33.18810937289483
194 35.593258 0.6661575506889443 -34.428961595964644
195 36.896317 0.9243442624978808 -34.29750976013116
196 36.76949 0.9217977509127361 -34.93790829051595
197 37.3182 0.9268233524391494 -34.75639725790526
198 37.1628 0.9217682310329729 -34.684709843056744
199 37.04298 0.9364708284413749 -34.67498021881172
200 36.9522 0.9314883692666568 -34.97541945482416
201 37.29302 0.9051297553709508 -34.58935491097996
202 36.84445 0.917195607204032 -35.56494030266732
203 37.8752 0.9261660138052312 -35.24118474247582
204 37.530926 0.9121773443712125 -35.46087444271772
205 37.71016 0.688185524800524 -35.56620879342346
206 37.804 0.9122784306977619 -36.337279976877575
207 38.605953 0.9216619231304548 -35.7220854203205
208 37.98588 0.6879699312151352 -35.380218334394016
209 37.589195 0.9266579929100349 -36.31389272266184
210 38.455853 0.9287665648834831 -35.79633637527839
211 37.988323 0.9148576416982418 -36.87690825134445
212 39.08839 0.931334875465541 -36.25540840905351
213 38.474453 0.9359645389700805 -36.62131958756222
214 38.824875 0.9360982186792379 -36.99300285560583
215 39.16117 0.919502187214348 -37.72916018376232
216 39.915825 0.9196336520298514 -37.276088142878145
217 39.414993 0.9361141518654124 -36.948269481086236
218 39.071304 0.9217746537258305 -36.847579406266775
219 38.944492 0.9217036436939395 -37.37577390053825
220 39.50542 0.9314410832092205 -36.898736892106854
221 38.98637 0.9362485743323443 -37.29534145923264
222 39.3999 0.9410787947012804 -38.57835454939353
223 40.628113 0.9408607363433215 -37.962280345241524
224 39.981403 0.9387840341351679 -38.41155579729448
225 40.47344 0.9410082120740806 -38.517371395420284
226 40.50992 0.905073117126198 -38.330723559230904
227 40.403015 0.9241754339995079 -38.82376843423213
228 40.837246 0.9169308117585409 -38.63620495396037
229 40.664013 0.9386896822080206 -38.90706423092294
230 40.896824 0.9339068463683333 -38.696793116982434
231 40.702633 0.9145780311554074 -38.65284400771682
232 40.62579 0.9006325662796982 -39.637494454352314
233 41.626167 0.9338129447333415 -39.78114728801758
234 41.796886 0.9242801144903465 -39.65097847274954
235 41.57147 0.9338876655063292 -39.77878315731534
236 41.68901 0.9384326516475999 -40.16391482862492
237 42.156246 0.933748701808647 -40.17643420414302
238 42.09112 0.9195355828867677 -40.69882361771501
239 42.6522 0.9433070222502209 -39.88425628693875
240 41.795567 0.9312496313482089 -40.06667512071424
241 41.96842 0.6649959661420343 -40.25920371553695
242 42.185585 0.933840538984267 -39.968183863842874
243 41.943356 0.9434117715961123 -41.71691005353176
244 43.644 0.9361124592580883 -41.13699389861304
245 43.023914 0.9338435982410622 -41.28230195762135
246 43.193844 0.9239637685647102 -41.560060418049325
247 43.493023 0.9289605367593431 -41.59855809952668
248 43.51518 0.9146019879654815 -41.36367037269284
249 43.246273 0.9217172360497119 -41.47216798973369
250 43.336735 0.9120347967280558 -42.23140674910907
251 44.096336 0.919600958481893 -41.837013226281165
252 43.689465 0.9168157005405911 -42.09410495072614
253 43.97186 0.9313324891094498 -43.33269437968738
254 45.19633 0.9103041900176997 -42.74235257628621
255 44.50436 0.9316946553597163 -42.6459909477362
256 44.48472 0.94340699206266 -43.07054047866841
257 44.94371 0.9266275445921005 -43.677742960288704
258 45.51856 0.9314846136123487 -43.152400611873844
259 45.022953 0.9295360345348367 -43.85954580481574
260 45.649456 0.941096952293478 -43.238502375844305
261 45.010784 0.9409037386367854 -44.1959112758573
262 45.964397 0.9077044035477514 -43.272662943668806
263 45.04781 0.9312846816610567 -43.95372238782308
264 45.766148 0.9265385942747207 -43.79405790988835
265 45.61018 0.943503887562348 -43.60608123459338
266 45.40575 0.9482943023231649 -43.73149435438444
267 45.533237 0.9362164782704477 -44.456164408531464
268 46.25102 0.675526897066116 -44.91974124133422
269 46.70814 0.9386585174647938 -45.20236344500999
270 46.91764 0.93619619611082 -44.87311225838185
271 46.6652 0.9146629836113067 -44.77760838980033
272 46.537632 0.9338646365339074 -44.5743114139252
273 46.369617 0.9148357574921085 -45.31227919633451
274 47.010086 0.9288605841947968 -44.80082827942532
275 46.527664 0.9336166790243395 -44.824010647628135
276 46.56868 0.9169747400866328 -45.68804309217637
277 47.382 0.9312534770693748 -45.72352629848309
278 47.445965 0.9458445017987467 -45.68033726958641
279 47.387566 0.9411043389165059 -45.28531904494126
280 47.063156 0.9409118263306525 -46.07666833832236
281 47.77425 0.9435329156867446 -46.02638579461601
282 47.709824 0.9243080172423859 -45.56032259622697
283 47.30043 0.9434939869396793 -46.44801472012902
284 48.195957 0.9386534233274456 -46.71075045235024
285 48.422726 0.9338023676783402 -47.018610640639295
286 48.772133 0.940949919907394 -47.12493735282759
287 48.82328 0.9289707251709508 -47.17421957635267
288 48.87887 0.9433070222502209 -46.8329308939545
289 48.56754 0.9339093412701246 -46.83154870063826
290 48.513634 0.9482811955528008 -47.114446420883745
291 48.80567 0.943448994813494 -46.77148101741365
292 48.43017 0.9127387457555893 -47.812212565238646
293 49.537113 0.9459554529871025 -47.877821717117726
294 49.573853 0.9361937876461894 -48.151367839926486
295 49.842133 0.9292136528480338 -47.717037035109826
296 49.411465 0.9314517570811924 -48.11062758982222
297 49.814972 0.9312899682618766 -47.53576476351081
298 49.222496 0.9146626475861207 -48.04822511991112
299 49.701046 0.9384967847679025 -48.58175595749809
300 50.267567 0.9361233151078371 -49.155470433388636
301 50.83741 0.9530027920396302 -49.069428836087674
302 50.709015 0.917043442144682 -49.15210042853091
303 50.777473 0.9433968498615068 -49.595696934693784
304 51.26062 0.9384554245612082 -49.23064757322961
305 50.889492 0.9264896161164528 -49.762914131546964
306 51.405636 0.9242478273129594 -50.39477308448899
307 52.04346 0.9411936815457121 -48.88187845913276
308 50.469208 0.9268086724349757 -49.52015842460787
309 51.11294 0.931601310868745 -50.341840874423134
310 51.941063 0.9267573710458545 -49.67789876554459
311 51.296425 0.6822948807128411 -50.10123211311376
312 51.71586 0.9362690826160238 -50.61006503534119
313 52.24219 0.9338174702952926 -50.964289175170606
314 52.56595 0.9245980648792232 -49.86471802759799
315 51.452152 0.9315344304664925 -50.532457677194245
316 52.09473 0.9340373735001957 -50.60457752373329
317 52.186676 0.9169286143512682 -50.52824718691812
318 52.117687 0.9532653343308853 -50.96577980050135
319 52.553677 0.938578592064411 -51.0455748103674
320 52.595074 0.9362616050374165 -51.07663928798285
321 52.705574 0.9483645253012802 -51.52745472623567
322 53.113655 0.9289600004461873 -52.10178824000801
323 53.655758 0.9363359401344202 -51.79856635900364
324 53.396515 0.943281922979496 -51.96477465996567
325 53.52005 0.9360412715208521 -51.92645560684186
326 53.539406 0.9483098080886775 -51.38460532306398
327 52.96323 0.9434038428767193 -52.73174277648113
328 54.274914 0.9409543776131868 -52.86808538836149
329 54.414093 0.9241566366137663 -51.96624843494258
330 53.514637 0.9314256239321634 -52.25812006328773
331 53.80943 0.9123081283435847 -53.37812675284057
332 54.911587 0.9385990380860523 -52.37090489707615
333 53.952515 0.9579839794145553 -53.80769842632409
334 55.356075 0.9362773020521817 -53.22752059642615
335 54.766113 0.9196475097324497 -52.70105835588845
336 54.22362 0.933786296164048 -53.88103299918548
337 55.38281 0.9241819503751091 -53.79245257055031
338 55.321426 0.9339725860360755 -53.23907889397464
339 54.719875 0.9361013573651175 -53.55862202247147
340 55.056675 0.9410798597561829 -54.54624748743109
341 56.045376 0.9384642509919361 -54.290690354407715
342 55.84019 0.934013768170251 -54.37227055875233
343 55.82118 0.9411092297361782 -55.13394682197916
344 56.65467 0.9289731936906593 -54.45063631505906
345 55.937954 0.9457481311257834 -55.44111763155826
346 56.93847 0.9170716964756597 -54.008481631069905
347 55.50485 0.948337475801075 -55.85371567990805
348 57.381233 0.9389332438030509 -54.21278343345174
349 55.72363 0.9360277665886446 -56.36200450542284
350 57.926163 0.9385442036382121 -55.287349617799606
351 56.78759 0.9385688256596972 -54.98875789338471
352 56.47116 0.940906861329766 -55.46727402753033
353 56.927868 0.9265359248007349 -55.95690627681954
354 57.44238 0.9315254995164585 -55.90495707323904
355 57.372593 0.9383810801287759 -55.151627522555124
356 56.613277 0.9385497709620318 -56.583374575557535
357 58.040905 0.9289206400998612 -57.15375206458828
358 58.6588 0.938839813813307 -56.911435491045104
359 58.396004 0.9506629710160279 -56.67874534266195
360 58.077644 0.9336939497134896 -57.306086903379736
361 58.794327 0.9239713089943248 -57.246686351934656
362 58.715927 0.9388328058910975 -57.47209943351888
363 58.92101 0.9363103601110734 -56.88259152303607
364 58.360832 0.9409458416111852 -57.866559304079786
365 59.32959 0.9289849894538038 -57.61979434937067
366 59.0752 0.919502506767974 -57.6007705067422
367 59.061634 0.9289135610426666 -57.98020929897457
368 59.40823 0.9507884626324186 -58.311167046903996
369 59.7086 0.9361686893229114 -59.39663606416072
370 60.845245 0.9316117947380224 -59.371685121169286
371 60.808956 0.9336126699285775 -58.3735555935064
372 59.83399 0.9289408034905836 -58.70765114684771
373 60.12649 0.9386334413532126 -58.79655104446919
374 60.199677 0.9411078959568269 -58.37978323756041
375 59.746197 0.9193302706682228 -58.158461247723665
376 59.517582 0.9361657977883846 -58.546475525862455
377 59.88474 0.941169473708928 -59.53714352573999
378 60.936394 0.9387133647247262 -58.89642004203374
379 60.307255 0.9409592803122776 -60.08621223649035
380 61.463406 0.9244480572174005 -60.25164412007419
381 61.627808 0.9385516383034639 -59.65470634874928
382 61.08458 0.9361042516110605 -59.658306001013536
383 61.058113 0.941038654307516 -59.357202616515686
384 60.75208 0.9313708523204801 -59.39215362242612
385 60.795376 0.9335584547105307 -60.92718077517706
386 62.32712 0.9338179989175789 -60.17290510762455
387 61.534187 0.9314299616780156 -60.56431999514331
388 61.925266 0.9458537806337974 -61.94726721471983
389 63.304695 0.9435795099202235 -60.387911320599784
390 61.779427 0.9363378626933384 -60.95682530783498
391 62.308586 0.9410536951463077 -62.038077348808514
392 63.43991 0.9457895324807266 -60.76609456964016
393 62.10059 0.9386046172118506 -61.43661312126889
394 62.799294 0.9338445971869483 -62.194463167186484
395 63.546772 0.9290592059910425 -61.34540248224434
396 62.686554 0.9337772738071038 -61.92923315461959
397 63.255863 0.9170651085588399 -61.94406577986058
398 63.29587 0.9386422848742059 -61.438265160148894
399 62.744526 0.9337356527720342 -63.321828239387976
400 64.6676 0.9171902704737648 -62.18541259205495
401 63.516994 0.9314962637890593 -62.54184636705649
402 63.861813 0.9339469478811487 -62.86847780473726
403 64.18908 0.6772017320486903 -62.909016008837916
404 64.23691 0.9195734347310877 -62.97319920085954
405 64.313446 0.9338988247941281 -62.641605360589296
406 63.92783 0.9339208927778799 -63.31623631940531
407 64.6377 0.9383006907406831 -64.57479770312608
408 65.90879 0.9386483283439363 -63.16528180991833
409 64.48343 0.9386139039436976 -62.484030683391175
410 63.737762 0.9339143685893947 -62.7564247961461
411 64.0631 0.9388207193382838 -63.982445512911205
412 65.2788 0.9384312579252562 -63.552321233206044
413 64.84384 0.941205947278271 -63.340142906776194
414 64.626945 0.9384893554651378 -64.24998175429047
415 65.51272 0.9288866528249509 -64.3691781845215
416 65.62501 0.9316821061507065 -64.23243505147533
417 65.521965 0.9507055905467419 -63.84915206111708
418 65.11914 0.9458303169073627 -63.92649566963884
419 65.23108 0.9579887370969611 -66.03297231283726
420 67.35751 0.9531433760568697 -66.27176425240023
421 67.49118 0.9530813412583439 -63.987338741355515
422 65.2738 0.9481835336088783 -65.2125453335299
423 66.529686 0.9386329780482606 -65.90432165891632
424 67.151344 0.9384484519870665 -65.03557589601297
425 66.26163 0.9384489166097247 -64.8865337249148
426 66.15624 0.9198374684364394 -65.86238122595837
427 67.0713 0.9338455961026672 -66.14336286195864
428 67.421875 0.9410978417221632 -65.36567428237997
429 66.59523 0.9435714166509804 -66.32997067481278
430 67.63621 0.9147511319719784 -66.19433843608725
431 67.41604 0.9197764559456639 -65.8439740410718
432 67.09952 0.929545391496548 -66.98199880744104
433 68.1913 0.9337979244654195 -65.96527017957828
434 67.244026 0.941074697720604 -66.37710333669884
435 67.602905 0.9217525125973284 -67.74361817203979
436 68.90475 0.9385762734603942 -67.55862885241712
437 68.80382 0.9387161691356409 -67.0004234336046
438 68.139656 0.9292589078596195 -67.44964640283584
439 68.61462 0.9337436779006592 -66.17002008030208
440 67.42282 0.9411585358768727 -67.22718152185865
441 68.45496 0.9313694121957602 -67.01996477501709
442 68.25402 0.9336688575468722 -67.91837407715634
443 69.15321 0.9338505902288008 -68.20088802209115
444 69.450165 0.9337341519649334 -67.30721314209967
445 68.487305 0.9288366592443782 -67.53832828384122
446 68.792366 0.9506415990524855 -68.7820096565673
447 69.98891 0.9191909857330078 -68.40369372062699
448 69.56446 0.9290320799251226 -67.26062333777375
449 68.49412 0.9267218492568214 -67.63568832965665
450 68.84266 0.9386529601820951 -67.41922834747929
451 68.58336 0.9193893585415871 -68.22671761860103
452 69.4336 0.9314536328919572 -68.4287177588607
453 69.59918 0.9216265158801866 -69.34136190841976
454 70.518105 0.9147348937978166 -69.84342882250567
455 70.950745 0.9314601664680866 -69.77049582961781
456 70.966354 0.9336538346310982 -70.54906453282715
457 71.70278 0.9219217688476227 -68.46819294255194
458 69.641235 0.9266098440848324 -69.88995761234622
459 71.05608 0.9337055134183398 -69.12690967394182
460 70.289055 0.9191032692161769 -69.4177785810578
461 70.609924 0.9482770683557084 -69.49131457860922
462 70.644104 0.9386152941331335 -69.31184194773087
463 70.49469 0.9458380228163248 -71.02913305874603
464 72.202065 0.9217244650362648 -71.00399700979757
465 72.147125 0.9241146740231106 -69.91240351498041
466 71.03903 0.9434610327742449 -69.89451136568607
467 71.04827 0.9360301730415713 -70.36350021174356
468 71.52441 0.9240795781734016 -70.6523114951764
469 71.841255 0.9433706255170797 -70.97800925196775
470 72.14248 0.936518060472732 -70.34140671796908
471 71.41308 0.6700125883845816 -71.37415581564974
472 72.48785 0.9432194857506386 -71.06887074495054
473 72.24132 0.6799191328717841 -72.6891986923582
474 73.81839 0.9409304272378429 -70.86523599575055
475 72.000565 0.9312640428367959 -71.71540876267521
476 72.87248 0.9409866763388373 -72.24009688646504
477 73.38629 0.9436687271547698 -71.5825211778181
478 72.69054 0.926685397593051 -72.76757627835993
479 73.88097 0.9265821387973513 -71.5544621509566
480 72.648766 0.9288098558361103 -71.09213235540517
481 72.20597 0.9458307258727466 -72.04496405793456
482 73.15429 0.9312569899960574 -72.789000667466
483 73.91122 0.9244934086446598 -72.70620548973753
484 73.79589 0.9291176658806618 -72.65343210596274
485 73.81641 0.9315298218400299 -73.14929774965174
486 74.2042 0.9265161649294796 -72.95100028294499
487 74.01866 0.9434752885678902 -74.44905302395495
488 75.49894 0.9433805336820021 -74.15705425672948
489 75.20378 0.9313477380662144 -73.06826244177253
490 74.16461 0.9243918675971277 -74.00443576762638
491 75.08786 0.9361491425082101 -73.25847584660967
492 74.355576 0.9361365932044579 -73.64635874598994
493 74.724525 0.9145690577365322 -73.44005104561197
494 74.54677 0.924213343391225 -73.47182932922905
495 74.57868 0.9217783441741852 -74.35021747514558
496 75.396805 0.9362565334543848 -74.50758929101985
497 75.537766 0.9433921498602896 -75.04858182848967
498 76.152336 0.9362876658572992 -74.06536623563017
499 75.100006 0.9557846980030433 -75.07351357248896
In [11]:
import matplotlib as mpl

plt.figure(figsize=(10,10))
ax = plt.axes()
for n in set(gmm.predict(latent)):
    covariances = gmm.covariances_[n][:2, :2]
    v, w = np.linalg.eigh(covariances)
    u = w[0] / np.linalg.norm(w[0])
    angle = np.arctan2(u[1], u[0])
    angle = 180 * angle / np.pi  # convert to degrees
    v = 2. * np.sqrt(2.) * np.sqrt(v)
    ell = mpl.patches.Ellipse(gmm.means_[n, :2], v[0], v[1],
                              180 + angle, color='k')
    ell.set_alpha(0.2)
    ax.add_patch(ell)
    
pylab.scatter(latent[:,0],latent[:,1],c=gmm.predict(latent))
Out[11]:
<matplotlib.collections.PathCollection at 0x125f6dfd0>
In [12]:
plt.figure(figsize=(10,10))
pylab.scatter(VAE_predict_mean.predict(data)[:,0],VAE_predict_mean.predict(data)[:,1],c=label)
Out[12]:
<matplotlib.collections.PathCollection at 0x125f2a358>
In [13]:
gmm.predict(latent)
Out[13]:
array([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
       4, 4, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2, 2, 2, 2, 2, 0, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 3, 3, 3, 3, 3, 3, 4,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 4])
In [14]:
import pandas as pd
result  = gmm.predict(latent)
df      = pd.DataFrame({ 'result' : gmm.predict(latent),'label' : label})
iremono = np.zeros((5,5))
for i in range(5):
    acc = df[df['label']==i]
    for j in range(5):
        iremono[i,j] = len(acc[acc['result']==j])
acc = iremono.max(1).sum()/len(label)
print(acc)   
0.982