First, let's configure some settings for this Jupyter notebook.
%matplotlib inline
from warnings import filterwarnings
filterwarnings("ignore")
import os
os.environ['MKL_THREADING_LAYER'] = 'GNU'
os.environ['THEANO_FLAGS'] = 'device=cpu'
import numpy as np
import pandas as pd
import pymc3 as pm
import seaborn as sns
import matplotlib.pyplot as plt
np.random.seed(12345)
rc = {'xtick.labelsize': 20, 'ytick.labelsize': 20, 'axes.labelsize': 20, 'font.size': 20,
      'legend.fontsize': 12.0, 'axes.titlesize': 10, 'figure.figsize': [12, 6]}
sns.set(rc=rc)
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
Now, let's import the `MLPClassifier` algorithm from the `pymc-learn` package.
import pmlearn
from pmlearn.neural_network import MLPClassifier
print('Running on pymc-learn v{}'.format(pmlearn.__version__))
Running on pymc-learn v0.0.1.rc0
Generate a synthetic two-moons dataset, standardize the features, and cast everything to Theano's float type.
from sklearn.datasets import make_moons
from sklearn.preprocessing import scale
import theano
floatX = theano.config.floatX
X, y = make_moons(noise=0.2, random_state=0, n_samples=1000)
X = scale(X)
X = X.astype(floatX)
y = y.astype(floatX)
## Plot the data
fig, ax = plt.subplots(figsize=(12, 8))
ax.scatter(X[y==0, 0], X[y==0, 1], label='Class 0')
ax.scatter(X[y==1, 0], X[y==1, 1], color='r', label='Class 1')
sns.despine(); ax.legend()
ax.set(xlabel='X1', ylabel='X2', title='Toy binary classification data set');
Split the data into training and test sets.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
Fit a `MLPClassifier` on the training data. By default, `fit` uses variational inference (ADVI), so the progress bar below reports the average ELBO loss.
model = MLPClassifier()
model.fit(X_train, y_train)
Average Loss = 140.51: 100%|██████████| 200000/200000 [02:46<00:00, 1203.69it/s]
Finished [100%]: Average Loss = 140.52
MLPClassifier(n_hidden=5)
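The default ADVI run is 200,000 iterations, as the log above shows. The run can be customized; assuming this version of `fit` accepts the `minibatch_size` and `inference_args` keyword arguments (check the docstring for your release), a minimal sketch:

# Sketch: a shorter, minibatched variational fit. `minibatch_size` and
# `inference_args` are assumed to be forwarded to PyMC3's ADVI;
# verify against your pymc-learn version before relying on this.
model_quick = MLPClassifier()
model_quick.fit(X_train, y_train,
                minibatch_size=500,            # stream the data in minibatches
                inference_args={'n': 50000})   # run fewer ADVI iterations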
Inspect the fit: plot the ELBO to check that the variational approximation converged, then look at the trace and forest plots of the network weights.
model.plot_elbo()
pm.traceplot(model.trace);
pm.forestplot(model.trace, varnames=["w_in_1"]);
pm.summary(model.trace)
parameter | mean | sd | mc_error | hpd_2.5 | hpd_97.5 |
---|---|---|---|---|---|
w_in_1__0_0 | 0.005568 | 0.853588 | 0.008590 | -1.690336 | 1.657664 |
w_in_1__0_1 | 2.625947 | 0.430473 | 0.004169 | 1.754085 | 3.449592 |
w_in_1__0_2 | -0.029436 | 0.857784 | 0.007223 | -1.654908 | 1.703846 |
w_in_1__0_3 | -2.395630 | 0.181738 | 0.001690 | -2.751851 | -2.040889 |
w_in_1__0_4 | 0.622473 | 0.020184 | 0.000191 | 0.584093 | 0.663157 |
w_in_1__1_0 | 0.004472 | 0.913144 | 0.009296 | -1.776564 | 1.778230 |
w_in_1__1_1 | -0.481947 | 0.301163 | 0.003205 | -1.086732 | 0.082603 |
w_in_1__1_2 | 0.005991 | 0.899893 | 0.007878 | -1.798096 | 1.727599 |
w_in_1__1_3 | -0.643017 | 0.116311 | 0.001178 | -0.872403 | -0.415040 |
w_in_1__1_4 | -0.219521 | 0.024275 | 0.000220 | -0.266726 | -0.171858 |
w_1_2__0_0 | 0.047623 | 1.073161 | 0.011119 | -2.021469 | 2.151871 |
w_1_2__0_1 | -0.011368 | 1.077858 | 0.009985 | -2.194813 | 2.021219 |
w_1_2__0_2 | -0.031757 | 1.054057 | 0.009998 | -2.083095 | 2.018043 |
w_1_2__0_3 | 0.020650 | 1.072316 | 0.010139 | -2.107038 | 2.109442 |
w_1_2__0_4 | -0.004587 | 0.089081 | 0.000801 | -0.180803 | 0.168265 |
w_1_2__1_0 | -0.046195 | 0.995340 | 0.010568 | -2.014609 | 1.858622 |
w_1_2__1_1 | -0.037873 | 1.019444 | 0.009749 | -2.040949 | 1.960829 |
w_1_2__1_2 | -0.021888 | 1.032725 | 0.011500 | -2.130446 | 1.901217 |
w_1_2__1_3 | 0.017286 | 1.002461 | 0.009124 | -2.031078 | 1.903586 |
w_1_2__1_4 | 1.081594 | 0.051777 | 0.000545 | 0.979248 | 1.182610 |
w_1_2__2_0 | 0.035723 | 1.060417 | 0.009704 | -2.036665 | 2.086344 |
w_1_2__2_1 | 0.026448 | 1.068230 | 0.010686 | -2.102493 | 2.071060 |
w_1_2__2_2 | -0.029435 | 1.035190 | 0.010280 | -1.957489 | 2.077871 |
w_1_2__2_3 | -0.000834 | 1.059046 | 0.011346 | -2.108175 | 1.998687 |
w_1_2__2_4 | -0.007699 | 0.090879 | 0.000956 | -0.183102 | 0.169459 |
w_1_2__3_0 | -0.032409 | 1.056289 | 0.010256 | -2.207693 | 1.981262 |
w_1_2__3_1 | 0.034940 | 1.027592 | 0.009972 | -2.005263 | 2.052578 |
w_1_2__3_2 | 0.008883 | 1.029169 | 0.010479 | -2.062519 | 1.946606 |
w_1_2__3_3 | -0.008496 | 1.036160 | 0.010453 | -2.109220 | 1.909782 |
w_1_2__3_4 | -1.304680 | 0.057650 | 0.000542 | -1.413446 | -1.188361 |
w_1_2__4_0 | -0.011010 | 1.082071 | 0.010123 | -2.120924 | 2.138763 |
w_1_2__4_1 | -0.034882 | 1.062334 | 0.009535 | -2.165315 | 2.008977 |
w_1_2__4_2 | -0.003337 | 1.065358 | 0.010463 | -2.068589 | 2.064908 |
w_1_2__4_3 | 0.045525 | 1.057940 | 0.010719 | -2.045978 | 2.035685 |
w_1_2__4_4 | -4.385587 | 0.094443 | 0.000898 | -4.570870 | -4.199793 |
w_2_out__0 | -0.010038 | 0.286819 | 0.002983 | -0.563166 | 0.559095 |
w_2_out__1 | -0.013392 | 0.286911 | 0.002836 | -0.572484 | 0.545143 |
w_2_out__2 | -0.013202 | 0.291041 | 0.003212 | -0.586738 | 0.550436 |
w_2_out__3 | 0.013324 | 0.289007 | 0.003026 | -0.581552 | 0.551274 |
w_2_out__4 | -5.422810 | 0.401398 | 0.003642 | -6.218766 | -4.651419 |
pm.plot_posterior(model.trace, varnames=["w_in_1"],
                  figsize=[14, 8]);
Predict labels for the held-out test set. Prediction draws posterior-predictive samples, hence the progress bar.
y_pred = model.predict(X_test)
100%|██████████| 2000/2000 [00:00<00:00, 2578.43it/s]
y_pred
array([False, True, True, False, False, True, True, True, True, False, False, True, True, False, True, True, True, False, False, True, True, True, True, True, False, True, False, False, False, False, False, True, False, True, True, False, True, True, True, False, False, False, True, True, False, False, False, True, False, False, False, True, True, False, True, True, True, False, True, False, True, False, True, False, False, True, True, False, False, True, False, False, False, True, True, True, True, True, True, False, False, False, False, False, True, True, True, False, False, False, True, True, True, True, False, True, False, True, True, False, True, True, True, True, False, False, False, True, False, False, False, False, True, False, False, True, False, True, True, False, False, True, True, True, False, True, True, True, False, True, False, True, True, False, False, True, True, True, True, False, True, True, False, False, True, False, True, False, True, True, False, True, False, True, True, True, False, False, False, False, True, False, True, False, True, False, False, False, False, True, True, False, True, False, True, False, True, False, True, True, True, True, False, True, True, True, False, False, True, False, True, True, False, True, False, True, False, False, True, False, False, True, True, False, True, True, True, False, True, True, True, True, True, False, True, False, True, True, False, True, False, False, True, True, False, False, True, True, True, False, False, True, True, True, False, True, False, False, True, False, False, True, False, False, True, True, False, False, True, True, False, True, True, False, False, False, True, False, False, False, False, False, False, False, True, False, True, False, False, False, True, False, True, False, False, True, False, False, True, False, True, True, True, False, True, True, True, True, True, False, False, False, True, False, True, False, False, False, False, False], dtype=bool)
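Because `y_pred` is a boolean array and `y_test` holds 0/1 floats, we can sanity-check the accuracy by hand; this should agree with `model.score` in the next cell:

# Manual accuracy: the fraction of test points where the hard
# prediction matches the label (True/False compare equal to 1.0/0.0).
print(np.mean(y_pred == y_test))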
Score the model on the test set (mean accuracy).
model.score(X_test, y_test)
100%|██████████| 2000/2000 [00:00<00:00, 2722.33it/s]
0.95999999999999996
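To visualize what the network has learned, we can evaluate `predict` on a grid over feature space and draw the decision regions. A sketch, assuming `predict` accepts any float array of shape `(n_samples, 2)`; keep the grid coarse, since every call draws posterior-predictive samples:

# Classify a coarse grid of points and shade the two decision regions.
grid_x, grid_y = np.meshgrid(np.linspace(-3, 3, 50), np.linspace(-3, 3, 50))
grid = np.stack([grid_x.ravel(), grid_y.ravel()], axis=1).astype(floatX)
grid_pred = model.predict(grid).astype(float)

fig, ax = plt.subplots(figsize=(12, 8))
ax.contourf(grid_x, grid_y, grid_pred.reshape(grid_x.shape), alpha=0.3)
ax.scatter(X_test[y_test == 0, 0], X_test[y_test == 0, 1], label='Class 0')
ax.scatter(X_test[y_test == 1, 0], X_test[y_test == 1, 1], color='r', label='Class 1')
ax.legend();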
Save the fitted model to disk, reload it into a fresh instance, and confirm that it scores identically.
model.save('pickle_jar/mlpc')
model_new = MLPClassifier()
model_new.load('pickle_jar/mlpc')
model_new.score(X_test, y_test)
100%|██████████| 2000/2000 [00:00<00:00, 2619.50it/s]
0.95999999999999996
Instead of variational inference, we can sample the posterior with MCMC by passing `inference_type='nuts'`.
model2 = MLPClassifier()
model2.fit(X_train, y_train, inference_type='nuts')
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [w_2_out, w_1_2, w_in_1]
100%|██████████| 2500/2500 [04:25<00:00, 9.42it/s]
There were 125 divergences after tuning. Increase `target_accept` or reparameterize.
There were 228 divergences after tuning. Increase `target_accept` or reparameterize.
There were 210 divergences after tuning. Increase `target_accept` or reparameterize.
There were 32 divergences after tuning. Increase `target_accept` or reparameterize.
The estimated number of effective samples is smaller than 200 for some parameters.
MLPClassifier(n_hidden=5)
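PyMC3 reports divergences after tuning, meaning some regions of the posterior were hard to sample. One way to follow the warning's advice is to raise `target_accept`. Whether and how `fit` forwards sampler settings is version-dependent, so treat the `inference_args` key below as an assumption and check the docstring:

# Sketch: re-fit with a higher target acceptance rate to reduce
# divergences. The 'nuts_kwargs' key is an assumption about how this
# pymc-learn version passes arguments through to pm.sample.
model2b = MLPClassifier()
model2b.fit(X_train, y_train, inference_type='nuts',
            inference_args={'nuts_kwargs': {'target_accept': 0.9}})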
pm.traceplot(model2.trace, varnames=["w_in_1"]);
Check convergence with the Gelman-Rubin statistic; values close to 1 indicate that the chains have mixed well.
pm.gelman_rubin(model2.trace)
{'w_in_1': array([[ 1.03392059,  1.00889386,  1.0116798 ,  1.00952281,  1.00310832],
        [ 1.00180089,  1.00662138,  1.00244567,  1.00753298,  1.00338383]]),
 'w_1_2': array([[ 0.99999921,  1.00373929,  1.00043873,  1.00022153,  1.00150073],
        [ 1.00202154,  1.00028483,  1.00173403,  1.00384901,  1.00022611],
        [ 1.00035073,  1.00026924,  1.00524066,  1.00006522,  1.00168698],
        [ 1.00206691,  1.00377702,  1.00243599,  1.00069978,  1.00472955],
        [ 0.99978974,  0.99992665,  1.00151647,  1.00214903,  1.00018014]]),
 'w_2_out': array([ 1.01048089,  1.0018095 ,  1.00558228,  1.00216195,  1.00162127])}
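A quick way to scan these arrays is to take the largest value across all weight groups; anything well above 1.1 (a common rule of thumb) would be a red flag. Here everything is close to 1.

# Largest Gelman-Rubin statistic across all weight matrices.
rhats = pm.gelman_rubin(model2.trace)
print(max(np.max(v) for v in rhats.values()))  # ~1.034 here, comfortably below 1.1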
The energy plot gives a further check on sampler health, and the forest plot shows the weight posteriors.
pm.energyplot(model2.trace);
pm.forestplot(model2.trace, varnames=["w_in_1"]);
pm.summary(model2.trace)
parameter | mean | sd | mc_error | hpd_2.5 | hpd_97.5 | n_eff | Rhat |
---|---|---|---|---|---|---|---|
w_in_1__0_0 | 0.225135 | 1.453129 | 0.090766 | -2.600963 | 2.967663 | 165.262816 | 1.033921 |
w_in_1__0_1 | 0.040237 | 1.432005 | 0.092519 | -2.821038 | 2.733033 | 156.605213 | 1.008894 |
w_in_1__0_2 | 0.018382 | 1.318325 | 0.084790 | -2.654008 | 2.752361 | 152.086243 | 1.011680 |
w_in_1__0_3 | 0.059441 | 1.520335 | 0.099235 | -2.778439 | 2.907296 | 171.541523 | 1.009523 |
w_in_1__0_4 | -0.105049 | 1.467413 | 0.106934 | -2.821507 | 2.759036 | 142.862990 | 1.003108 |
w_in_1__1_0 | 0.038815 | 0.617805 | 0.021710 | -1.325060 | 1.497353 | 811.520570 | 1.001801 |
w_in_1__1_1 | 0.048561 | 0.651136 | 0.033623 | -1.387215 | 1.528355 | 315.526252 | 1.006621 |
w_in_1__1_2 | -0.040393 | 0.630075 | 0.029703 | -1.522670 | 1.369573 | 350.442761 | 1.002446 |
w_in_1__1_3 | -0.006621 | 0.615670 | 0.023998 | -1.488595 | 1.321337 | 588.732171 | 1.007533 |
w_in_1__1_4 | -0.022356 | 0.602055 | 0.023388 | -1.399013 | 1.409409 | 630.152499 | 1.003384 |
w_1_2__0_0 | -0.030117 | 1.222961 | 0.032388 | -2.222790 | 2.459064 | 1465.347411 | 0.999999 |
w_1_2__0_1 | -0.033481 | 1.257416 | 0.036231 | -2.528749 | 2.320204 | 1122.000446 | 1.003739 |
w_1_2__0_2 | 0.058013 | 1.255037 | 0.032539 | -2.337473 | 2.535129 | 1445.368417 | 1.000439 |
w_1_2__0_3 | 0.012622 | 1.216279 | 0.031562 | -2.336343 | 2.350345 | 1548.011815 | 1.000222 |
w_1_2__0_4 | -0.022592 | 1.238055 | 0.036278 | -2.382896 | 2.395463 | 1216.492733 | 1.001501 |
w_1_2__1_0 | 0.004957 | 1.255870 | 0.033091 | -2.390911 | 2.475214 | 1327.119255 | 1.002022 |
w_1_2__1_1 | 0.045525 | 1.240332 | 0.026389 | -2.355288 | 2.492054 | 1738.691574 | 1.000285 |
w_1_2__1_2 | -0.080770 | 1.252471 | 0.034783 | -2.602587 | 2.212237 | 1506.411171 | 1.001734 |
w_1_2__1_3 | 0.011486 | 1.249588 | 0.030010 | -2.432115 | 2.427352 | 1752.220349 | 1.003849 |
w_1_2__1_4 | 0.019531 | 1.218963 | 0.030665 | -2.313649 | 2.383177 | 1517.420895 | 1.000226 |
w_1_2__2_0 | -0.022256 | 1.261408 | 0.033143 | -2.522347 | 2.358721 | 1228.355488 | 1.000351 |
w_1_2__2_1 | -0.016605 | 1.260692 | 0.030260 | -2.431030 | 2.443035 | 1762.632045 | 1.000269 |
w_1_2__2_2 | -0.039904 | 1.277497 | 0.036323 | -2.547160 | 2.316859 | 1293.686512 | 1.005241 |
w_1_2__2_3 | -0.007594 | 1.257849 | 0.034130 | -2.638556 | 2.284691 | 1585.661623 | 1.000065 |
w_1_2__2_4 | 0.024879 | 1.207795 | 0.029777 | -2.424846 | 2.292940 | 1751.025064 | 1.001687 |
w_1_2__3_0 | 0.005390 | 1.242961 | 0.034378 | -2.474162 | 2.308738 | 1271.123642 | 1.002067 |
w_1_2__3_1 | 0.039325 | 1.312166 | 0.049990 | -2.564783 | 2.642687 | 449.518792 | 1.003777 |
w_1_2__3_2 | -0.021349 | 1.300328 | 0.050331 | -2.933611 | 2.246004 | 431.211084 | 1.002436 |
w_1_2__3_3 | 0.019744 | 1.245684 | 0.027273 | -2.478284 | 2.358122 | 1915.759049 | 1.000700 |
w_1_2__3_4 | -0.010321 | 1.256208 | 0.036211 | -2.341684 | 2.450475 | 1107.995158 | 1.004730 |
w_1_2__4_0 | -0.058451 | 1.235791 | 0.027882 | -2.581343 | 2.202579 | 1695.014178 | 0.999790 |
w_1_2__4_1 | 0.012964 | 1.266865 | 0.035280 | -2.335460 | 2.498519 | 1313.828194 | 0.999927 |
w_1_2__4_2 | -0.020461 | 1.286353 | 0.034562 | -2.579003 | 2.365088 | 1301.591220 | 1.001516 |
w_1_2__4_3 | 0.009367 | 1.227171 | 0.026680 | -2.364722 | 2.360280 | 1782.650186 | 1.002149 |
w_1_2__4_4 | -0.022705 | 1.233417 | 0.031601 | -2.367693 | 2.367387 | 1551.870684 | 1.000180 |
w_2_out__0 | -0.029666 | 2.331660 | 0.076394 | -4.488181 | 4.396616 | 764.040782 | 1.010481 |
w_2_out__1 | -0.037847 | 2.354610 | 0.085388 | -4.587552 | 4.386218 | 720.542202 | 1.001810 |
w_2_out__2 | 0.127920 | 2.375094 | 0.092250 | -4.319635 | 4.449930 | 523.489082 | 1.005582 |
w_2_out__3 | 0.030048 | 2.295993 | 0.076132 | -4.349295 | 4.310776 | 840.771489 | 1.002162 |
w_2_out__4 | 0.002714 | 2.244351 | 0.078376 | -4.295934 | 4.466476 | 769.253983 | 1.001621 |
Compared with the ADVI fit, the NUTS posteriors are much broader and centered near zero: the hidden units are exchangeable, so the posterior has many symmetric modes, and marginals averaged over chains visiting different modes wash out.
pm.plot_posterior(model2.trace, varnames=["w_in_1"],
                  figsize=[14, 8]);
As before, predict and score on the test set.
y_pred2 = model2.predict(X_test)
100%|██████████| 2000/2000 [00:00<00:00, 2569.74it/s]
y_pred2
array([False, True, True, False, False, True, True, True, True, False, False, True, True, False, True, True, True, False, False, True, True, True, True, True, False, True, False, False, False, False, False, True, False, True, True, False, True, True, True, False, False, False, True, True, False, False, False, True, False, False, False, True, True, False, True, True, True, False, True, False, True, False, True, False, False, True, True, False, False, True, False, False, False, True, True, True, True, True, True, False, False, False, False, False, True, True, True, False, False, False, True, True, True, True, False, True, False, True, True, False, True, True, True, True, False, False, False, True, False, False, False, False, True, False, False, True, False, True, True, False, False, True, True, True, False, True, True, True, False, True, False, True, True, False, False, True, True, True, True, False, True, True, False, False, True, False, True, False, True, True, False, True, False, True, True, True, False, False, False, False, True, False, True, False, True, False, False, False, False, True, True, False, True, False, True, False, True, False, True, True, True, True, False, True, True, True, False, False, True, False, True, True, False, True, False, True, False, False, True, False, False, True, True, False, True, True, True, False, True, True, True, True, True, False, True, False, True, True, False, True, False, False, True, True, False, False, True, True, True, False, False, True, True, True, False, True, False, False, True, False, False, True, False, False, True, True, False, False, True, True, False, True, True, False, False, False, True, False, False, False, False, False, False, False, True, False, True, False, False, False, True, False, True, False, False, True, False, False, True, False, True, True, True, False, True, True, True, True, True, False, False, False, True, False, True, False, False, False, False, False], dtype=bool)
model2.score(X_test, y_test)
100%|██████████| 2000/2000 [00:00<00:00, 2645.86it/s]
0.96333333333333337
Finally, save and reload the NUTS-fitted model and verify its score.
model2.save('pickle_jar/mlpc2')
model2_new = MLPClassifier()
model2_new.load('pickle_jar/mlpc2')
model2_new.score(X_test, y_test)
100%|██████████| 2000/2000 [00:00<00:00, 2550.55it/s]
0.95999999999999996