import sklearn.mixture
import numpy as np
from keras import Input
from keras.layers import Dense, Lambda, Activation
from keras.models import Model
from keras.layers.normalization import BatchNormalization
from keras import backend
import matplotlib.pyplot as plt
import autograd.numpy.random as npr
import pylab
from sklearn.metrics import adjusted_rand_score
%matplotlib inline
def make_pinwheel_data(radial_std, tangential_std, num_classes, num_per_class, rate):
    # Generate "pinwheel" data: Gaussian blobs warped along spiral arms.
    rads = np.linspace(0, 2*np.pi, num_classes, endpoint=False)
    features = npr.randn(num_classes*num_per_class, 2) \
        * np.array([radial_std, tangential_std])
    features[:, 0] += 1.
    labels = np.repeat(np.arange(num_classes), num_per_class)
    angles = rads[labels] + rate * np.exp(features[:, 0])
    rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
    rotations = np.reshape(rotations.T, (-1, 2, 2))
    return 10*np.einsum('ti,tij->tj', features, rotations), labels
num_clusters = 5             # number of clusters in the pinwheel data
samples_per_cluster = 200    # number of samples per cluster
batch_size = num_clusters * samples_per_cluster  # we train on the full dataset as a single batch
data, label = make_pinwheel_data(0.3, 0.05, num_clusters, samples_per_cluster, 0.25)
plt.figure(figsize=(10,10))
pylab.scatter(data[:,0],data[:,1],c=label)
[scatter plot: the five pinwheel arms in data space, colored by true label]
def vae_loss(y_true, y_pred):
    # Reconstruction term: Gaussian negative log-likelihood of the input
    # under the decoder output distribution N(y_mean, exp(y_log_var)).
    xent_loss = backend.sum((backend.square(y_true - y_mean))/(backend.exp(y_log_var)) + backend.log(np.pi * 2) + y_log_var, axis=-1) * 0.5
    # -log q(z|x) evaluated at the sampled z: a single-sample estimate of
    # the approximate-posterior entropy term.
    kl_loss = backend.sum((backend.square(z - z_mean))/(backend.exp(z_log_var)) + backend.log(np.pi * 2) + z_log_var, axis=-1) * 0.5
    # gmm_loss is the per-sample log-density of z under the current GMM prior,
    # fed in through the auxiliary input defined below.
    return backend.mean(xent_loss - kl_loss - gmm_loss)
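Put together, this is a single-sample Monte Carlo estimate of the negative evidence lower bound, with the GMM playing the role of the prior (the notation below is mine; z is the reparameterized sample):

\[
\mathcal{L}(x) \,=\, -\log p_\theta(x \mid z) \,+\, \log q_\phi(z \mid x) \,-\, \log p_{\mathrm{GMM}}(z),
\qquad z \sim q_\phi(z \mid x),
\]

so minimizing \(\mathcal{L}\) maximizes the usual ELBO \( \mathbb{E}_{q}\bigl[\log p_\theta(x \mid z) + \log p_{\mathrm{GMM}}(z) - \log q_\phi(z \mid x)\bigr] \).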
def sampling(args):
    # Reparameterization trick: z = mean + sigma * epsilon with epsilon ~ N(0, I),
    # which keeps the sampling step differentiable w.r.t. mean and log_var.
    mean, log_var = args
    epsilon = backend.random_normal(shape=backend.shape(mean), mean=0., stddev=1.)
    return mean + backend.exp(log_var / 2) * epsilon
original_dim = data.shape[1]
intermediate_dim = 100
latent_dim = 2

x = Input(shape=(original_dim,))
# Auxiliary input: carries each sample's GMM log-likelihood into the loss.
gmm_loss = Input(shape=(1,))

# Encoder
encoder = Dense(intermediate_dim)(x)
encoder = BatchNormalization()(encoder)
encoder = Activation('tanh')(encoder)
z_mean = Dense(latent_dim, activation='linear')(encoder)
z_log_var = Dense(latent_dim, activation='linear')(encoder)
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# Decoder
decoder = Dense(intermediate_dim)(z)
decoder = BatchNormalization()(decoder)
decoder = Activation('tanh')(decoder)
y_mean = Dense(original_dim, activation='linear')(decoder)
y_log_var = Dense(original_dim, activation='linear')(decoder)
y = Lambda(sampling, output_shape=(original_dim,))([y_mean, y_log_var])
VAE = Model(inputs=[x, gmm_loss], outputs=y)
VAE_latent = Model(inputs=x, outputs=z)
VAE_latent_mean = Model(inputs=x, outputs=z_mean)
VAE_latent_var = Model(inputs=x, outputs=z_log_var)
VAE_predict_mean = Model(inputs=x, outputs=y_mean)
VAE_predict_var = Model(inputs=x, outputs=y_log_var)
VAE.compile(optimizer='adam', loss=vae_loss)
VAE.summary()
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 2)            0
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 100)          300         input_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 100)          400         dense_1[0][0]
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 100)          0           batch_normalization_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 2)            202         activation_1[0][0]
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 2)            202         activation_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 2)            0           dense_2[0][0]
                                                                 dense_3[0][0]
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 100)          300         lambda_1[0][0]
__________________________________________________________________________________________________
batch_normalization_2 (BatchNor (None, 100)          400         dense_4[0][0]
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 100)          0           batch_normalization_2[0][0]
__________________________________________________________________________________________________
dense_5 (Dense)                 (None, 2)            202         activation_2[0][0]
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 2)            202         activation_2[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 2)            0           dense_5[0][0]
                                                                 dense_6[0][0]
==================================================================================================
Total params: 2,208
Trainable params: 1,808
Non-trainable params: 400
__________________________________________________________________________________________________
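Routing the GMM log-likelihood through an auxiliary Input so that it can appear in a custom loss is an old-Keras workaround. A minimal alternative sketch, assuming a Keras 2 version whose Model.add_loss accepts a symbolic tensor (the pattern used in the official Keras VAE example); it reuses the tensors defined above, and VAE_alt is my name, not part of the original:

# Sketch: attach the negative ELBO directly instead of compiling with a loss function.
neg_elbo = backend.mean(
    0.5 * backend.sum(backend.square(x - y_mean) / backend.exp(y_log_var)
                      + backend.log(2 * np.pi) + y_log_var, axis=-1)
    - 0.5 * backend.sum(backend.square(z - z_mean) / backend.exp(z_log_var)
                        + backend.log(2 * np.pi) + z_log_var, axis=-1)
    - gmm_loss)
VAE_alt = Model(inputs=[x, gmm_loss], outputs=y)
VAE_alt.add_loss(neg_elbo)
VAE_alt.compile(optimizer='adam')  # no loss argument needed

Training then takes no target array: VAE_alt.train_on_batch([data, likelihood], None).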
from scipy.stats import multivariate_normal

epoch = 500
latent = VAE_latent.predict(data)
# max_iter=1: each .fit() call performs a single EM iteration, so the GMM
# tracks the moving latent distribution instead of being fit to convergence.
gmm = sklearn.mixture.GaussianMixture(n_components=num_clusters, covariance_type='full', max_iter=1)
gmm.fit(latent)
pi_i = gmm.predict_proba(latent)   # responsibilities r_ik
mu_ = gmm.means_
sigma_ = gmm.covariances_
# Per-sample responsibility-weighted log-density of the latent codes.
likelihood = np.zeros(data.shape[0])
for m, s, pi in zip(mu_, sigma_, pi_i.T):
    likelihood += pi * multivariate_normal.logpdf(latent, m, s)
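In other words, what gets fed to the loss as gmm_loss is the responsibility-weighted log-density of each latent code (my notation; \(r_{ik}\) are the entries of pi_i):

\[
\text{likelihood}_i \,=\, \sum_{k=1}^{5} r_{ik} \,\log \mathcal{N}(z_i \mid \mu_k, \Sigma_k),
\qquad r_{ik} = p(c_i = k \mid z_i),
\]

i.e. the E-step expected log-density rather than the marginal \(\log p(z_i)\); if the marginal were wanted instead, gmm.score_samples(latent) would provide it directly.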
for i in range(epoch):
    # Gradient step on the VAE, feeding the current GMM log-likelihood as input.
    loss = VAE.train_on_batch([data, likelihood], data)
    # One EM step on the GMM for the new latent codes, then refresh the
    # component parameters and responsibilities before recomputing the likelihood.
    latent = VAE_latent.predict(data)
    gmm.fit(latent)
    pi_i = gmm.predict_proba(latent)
    mu_ = gmm.means_
    sigma_ = gmm.covariances_
    likelihood = np.zeros(data.shape[0])
    for m, s, pi in zip(mu_, sigma_, pi_i.T):
        likelihood += pi * multivariate_normal.logpdf(latent, m, s)
    print(i, loss, adjusted_rand_score(gmm.predict(latent), label), likelihood.mean())
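The weighted log-density computation appears twice above; a small vectorized helper would keep it in one place (a sketch — expected_gmm_loglik is my name, assuming numpy and scipy.stats.multivariate_normal as imported above):

def expected_gmm_loglik(codes, gmm):
    # Per-sample responsibility-weighted log-density under a fitted GaussianMixture.
    resp = gmm.predict_proba(codes)                     # (n_samples, n_components)
    log_dens = np.stack([multivariate_normal.logpdf(codes, m, s)
                         for m, s in zip(gmm.means_, gmm.covariances_)], axis=1)
    return np.sum(resp * log_dens, axis=1)              # (n_samples,)

With it, the two blocks above reduce to likelihood = expected_gmm_loglik(latent, gmm).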
(columns: epoch, loss, adjusted Rand index, mean GMM log-likelihood; 500 epochs, abridged)
0 103.41993 0.14651976688765464 -4.642745632400823
1 95.59113 0.12824113232861528 -4.768401059745967
2 87.78585 0.1427903809605796 -5.087102852087367
3 79.762024 0.12900684280492689 -5.095885069702208
4 73.61033 0.14716991183681807 -5.224178837160909
...
497 75.537766 0.9433921498602896 -75.04858182848967
498 76.152336 0.9362876658572992 -74.06536623563017
499 75.100006 0.9557846980030433 -75.07351357248896
import matplotlib as mpl

plt.figure(figsize=(10, 10))
ax = plt.axes()
for n in set(gmm.predict(latent)):
    # Draw each occupied component as a 2-sigma covariance ellipse.
    covariances = gmm.covariances_[n][:2, :2]
    v, w = np.linalg.eigh(covariances)
    u = w[0] / np.linalg.norm(w[0])
    angle = np.arctan2(u[1], u[0])
    angle = 180 * angle / np.pi  # convert to degrees
    v = 2. * np.sqrt(2.) * np.sqrt(v)
    ell = mpl.patches.Ellipse(gmm.means_[n, :2], v[0], v[1],
                              180 + angle, color='k')
    ell.set_alpha(0.2)
    ax.add_patch(ell)
pylab.scatter(latent[:, 0], latent[:, 1], c=gmm.predict(latent))
[scatter plot: latent codes colored by predicted GMM cluster, with component covariance ellipses]
plt.figure(figsize=(10,10))
pylab.scatter(VAE_predict_mean.predict(data)[:,0],VAE_predict_mean.predict(data)[:,1],c=label)
[scatter plot: decoder means VAE_predict_mean.predict(data), colored by true label]
gmm.predict(latent)
array([4, 4, 4, ..., 3, 3, 4])
import pandas as pd

result = gmm.predict(latent)
df = pd.DataFrame({'result': result, 'label': label})
# Contingency table (iremono): rows = true labels, columns = predicted clusters.
iremono = np.zeros((num_clusters, num_clusters))
for i in range(num_clusters):
    sub = df[df['label'] == i]
    for j in range(num_clusters):
        iremono[i, j] = len(sub[sub['result'] == j])
# Accuracy with each true label greedily matched to its most frequent cluster.
acc = iremono.max(1).sum() / len(label)
print(acc)
0.982
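The greedy row maximum can in principle match two labels to the same cluster. A stricter one-to-one matching via the Hungarian algorithm, as a sketch (cluster_accuracy is my name; scipy.optimize.linear_sum_assignment does the assignment):

from scipy.optimize import linear_sum_assignment

def cluster_accuracy(pred, truth, n):
    # Contingency table, then the label<->cluster permutation that
    # maximizes the number of correctly matched points.
    table = np.zeros((n, n))
    for p, t in zip(pred, truth):
        table[t, p] += 1
    rows, cols = linear_sum_assignment(-table)  # negate to maximize
    return table[rows, cols].sum() / len(truth)

cluster_accuracy(result, label, num_clusters)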