from sklearn.datasets import load_boston
from sklearn.linear_model import LinearRegression
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import re
import requests
from pprint import pprint
import lightgbm
# Report which NumPy and pandas releases are active in this session.
for _module in (np, pd):
    print(_module.__version__)
1.19.5 1.1.5
# Two independent samples of 30 values each: uniform noise shifted by 2 and by 3.
x_1 = [2 + np.random.rand() for _ in range(30)]
x_2 = [3 + np.random.rand() for _ in range(30)]

# Sample mean of each coordinate.
x_1_mean = np.mean(x_1)
print("x_1_mean: ", x_1_mean)
x_2_mean = np.mean(x_2)
print("x_2_mean: ", x_2_mean)

# Mean-centred copies of both samples (NumPy arrays).
x_3 = np.asarray(x_1) - x_1_mean
x_4 = np.asarray(x_2) - x_2_mean

# Append the centred values in place, so each list now holds the 30 raw values
# followed by their 30 centred counterparts; both clouds end up in one plot.
x_1 += x_3.tolist()
x_2 += x_4.tolist()
print(x_1)

plt.scatter(x_1, x_2)
x_1_mean: 2.4788976136540706 x_2_mean: 3.575013308992982 [2.739137061993336, 2.297726638350044, 2.0465701504695466, 2.5123187144322308, 2.0142383981079313, 2.0999984969854766, 2.3025907206822316, 2.1269114517080814, 2.9578184107081, 2.637733881417911, 2.0509559274116818, 2.9246823363324674, 2.0648582859454243, 2.707127302741993, 2.211676899817596, 2.6667216111796037, 2.9605484280248695, 2.443463223026553, 2.8583839139399596, 2.668642439399084, 2.928738778262046, 2.433515229091165, 2.1667413669575817, 2.450858534195633, 2.9219576323221776, 2.839891632521834, 2.023154121011913, 2.7020891166705274, 2.0630896902065943, 2.5447880157085403, 0.2602394483392656, -0.1811709753040267, -0.432327463184524, 0.033421100778160184, -0.4646592155461393, -0.378899116668594, -0.176306892971839, -0.3519861619459892, 0.47892079705402946, 0.1588362677638404, -0.4279416862423888, 0.44578472267839686, -0.41403932770864627, 0.22822968908792252, -0.2672207138364744, 0.18782399752553314, 0.48165081437079893, -0.03543439062751741, 0.37948630028588903, 0.1897448257450134, 0.4498411646079754, -0.04538238456290555, -0.3121562466964889, -0.028039079458437755, 0.44306001866810707, 0.3609940188677636, -0.45574349264215774, 0.22319150301645685, -0.4158079234474763, 0.06589040205446972]
<matplotlib.collections.PathCollection at 0x7efdab3e8750>
!pip install mglearn
Collecting mglearn Downloading https://files.pythonhosted.org/packages/65/38/8aced26fce0b2ae82c3c87cd3b6105f38ca6d9d51704ecc44aa54473e6b9/mglearn-0.1.9.tar.gz (540kB) |████████████████████████████████| 542kB 5.7MB/s Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from mglearn) (1.19.5) Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from mglearn) (3.2.2) Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from mglearn) (0.22.2.post1) Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from mglearn) (1.1.5) Requirement already satisfied: pillow in /usr/local/lib/python3.7/dist-packages (from mglearn) (7.1.2) Requirement already satisfied: cycler in /usr/local/lib/python3.7/dist-packages (from mglearn) (0.10.0) Requirement already satisfied: imageio in /usr/local/lib/python3.7/dist-packages (from mglearn) (2.4.1) Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from mglearn) (1.0.1) Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mglearn) (1.3.1) Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mglearn) (2.8.1) Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->mglearn) (2.4.7) Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->mglearn) (1.4.1) Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->mglearn) (2018.9) Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler->mglearn) (1.15.0) Building wheels for collected packages: mglearn Building wheel for mglearn (setup.py) ... 
done Created wheel for mglearn: filename=mglearn-0.1.9-py2.py3-none-any.whl size=582638 sha256=4ce4abb8ea8b78cfa0eef820d1d4fced83dae298521ffa0df25d8f3f9d16a433 Stored in directory: /root/.cache/pip/wheels/eb/a6/ea/a6a3716233fa62fc561259b5cb1e28f79e9ff3592c0adac5f0 Successfully built mglearn Installing collected packages: mglearn Successfully installed mglearn-0.1.9
import mglearn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
# Two interleaving half-moon classes: 200 points with noise level 0.1, seeded
# so the sample is reproducible. `moons` is the (points, labels) pair.
moons = make_moons(200, noise=0.1, random_state=0)
moons
(array([[ 7.92357355e-01, 5.02648573e-01], [ 1.63158315e+00, -4.63896705e-01], [-6.71092674e-02, 2.67767057e-01], [-1.04412427e+00, -1.82607610e-01], [ 1.76704822e+00, -1.98609868e-01], [ 1.90607398e+00, -7.10915927e-02], [ 9.62192129e-01, 2.61986075e-01], [ 8.86813848e-01, -4.84896235e-01], [ 8.68935196e-01, 3.61092776e-01], [ 1.15352953e+00, -5.72352929e-01], [-3.70714493e-01, 7.21655833e-01], [ 1.95523229e-01, -2.64439358e-01], [ 1.80949658e+00, -1.94707885e-01], [ 1.29046617e+00, -3.23893778e-01], [-8.23662620e-01, 7.17643700e-01], [ 9.85881879e-01, 1.96706218e-01], [ 5.81943503e-01, 9.06311048e-01], [ 2.76118902e-01, -2.87451069e-01], [ 3.17284813e-01, 8.54200996e-01], [-8.52697952e-01, 9.32383274e-01], [ 1.97796688e+00, 1.04796611e-02], [ 8.64249290e-01, -4.63242870e-01], [ 2.57499053e-01, 1.00362573e+00], [ 1.28213982e+00, -2.86519496e-01], [-2.29884751e-01, 1.10576822e+00], [ 2.29022777e-01, 1.42127683e-02], [ 1.89450511e+00, -1.02565542e-01], [ 1.86521084e+00, 1.96753138e-02], [ 9.76884442e-02, -2.94749068e-02], [ 1.03579202e-01, 8.07972963e-01], [ 9.07891105e-01, 5.59120905e-01], [-2.82483560e-01, 6.97029880e-01], [ 9.49718195e-01, -4.29281289e-01], [ 2.44475950e-02, 2.31178102e-01], [ 7.62407790e-01, 7.01679032e-01], [ 1.87210764e+00, 2.35600223e-01], [ 4.77778324e-01, -1.64382571e-01], [ 9.31783953e-01, 7.18760735e-02], [-8.12884298e-01, 5.62089952e-01], [-1.00520615e-01, 4.03688699e-01], [ 5.82424127e-01, -4.68412597e-01], [-3.77657209e-01, 8.52782368e-01], [ 6.28019791e-01, 6.90115928e-01], [ 1.21605162e-01, 3.53588146e-02], [ 1.54344475e+00, -3.47674056e-01], [-3.30952410e-01, 1.00049623e+00], [ 1.06619086e+00, 1.33052899e-01], [-7.14739658e-01, 6.19454519e-01], [ 1.63968968e-01, -1.74798733e-01], [ 9.33487565e-01, -4.40111162e-01], [-7.63797985e-01, 8.26528213e-01], [ 3.79557530e-01, -2.02301374e-01], [ 2.07209923e+00, 3.52001635e-02], [-8.11609003e-01, 5.10548387e-02], [ 1.10617390e+00, -5.00532861e-01], [ 7.83108552e-01, 6.23681079e-01], [ 
5.19314634e-01, 6.58138927e-01], [ 1.15369377e+00, -5.06557246e-01], [ 2.01456988e-01, 9.30021747e-01], [-2.10315426e-01, 1.00700943e+00], [ 2.56648130e-01, -2.48738776e-01], [-1.03710314e+00, 2.30748836e-01], [ 1.72148759e+00, -2.08345383e-01], [ 9.24799914e-01, 4.94247065e-01], [ 1.08911468e+00, -4.40836789e-01], [-1.16502510e+00, 2.52779074e-01], [-1.75082328e-01, 1.04012819e+00], [ 1.51596993e+00, -3.52163413e-01], [ 1.08086088e-01, 9.56975505e-01], [-9.57080149e-01, 5.12567989e-01], [ 6.47834107e-01, -4.86523627e-01], [-1.02340159e+00, 5.31509316e-01], [ 3.13142058e-01, -5.09754017e-01], [-2.64515467e-01, 5.80672612e-01], [ 8.04988122e-01, -3.71624596e-01], [-9.46347813e-01, 4.42308479e-01], [ 6.54970714e-02, 1.53680672e-01], [ 9.50046139e-01, 3.89984583e-01], [ 5.58939215e-01, 8.09836879e-01], [ 7.25701640e-01, -2.80462884e-01], [ 2.03083726e+00, 1.20170801e-01], [-1.51916239e-01, 9.58192690e-01], [ 6.33253492e-01, -5.64187401e-01], [ 2.02082877e+00, 5.18545302e-01], [ 1.42594806e-02, 2.31861328e-01], [-3.31332409e-01, 1.13198237e+00], [ 9.10306386e-01, 2.01247355e-02], [-3.81510263e-01, 9.71275515e-01], [ 6.14424862e-01, -3.85865038e-01], [ 1.98240620e+00, 3.50624890e-01], [ 4.57707908e-01, 8.94353953e-01], [ 1.03913732e+00, 3.89233822e-01], [ 1.98947025e+00, 1.83673590e-01], [-6.70964946e-01, 7.93803059e-01], [ 2.01955021e+00, 3.25289242e-01], [ 1.77991353e+00, -1.11649028e-01], [ 1.74859902e+00, -2.65552069e-01], [ 4.91433723e-01, -2.73866374e-01], [-1.06817822e+00, 2.97496839e-01], [ 5.66538453e-02, 2.68583098e-02], [ 1.33536908e+00, -6.97952991e-01], [ 1.98056273e+00, 3.68945922e-01], [ 8.04337273e-02, 1.15888076e+00], [ 1.06172312e+00, -4.77816364e-02], [-8.13810949e-01, 4.55071382e-01], [ 1.11442338e-01, 3.94660905e-01], [-9.84200490e-01, 8.13605456e-02], [-5.03822097e-01, 1.02687047e+00], [ 1.56581318e+00, -3.50444504e-01], [ 2.00175240e-01, 8.29137530e-01], [-5.89576838e-01, 7.91803173e-01], [ 5.91096314e-02, 1.04656810e+00], [-8.64462243e-02, 
1.14115330e+00], [-8.03546811e-01, 3.68143826e-01], [ 9.36602233e-01, 3.52975983e-01], [ 1.50812484e+00, -5.68831767e-01], [ 4.67190679e-01, 6.57821620e-01], [ 1.97192261e+00, 1.74641987e-01], [ 6.03754895e-01, -4.01384358e-01], [ 6.81080119e-01, 4.27058398e-01], [ 8.53498280e-01, 6.23877292e-01], [-6.34861812e-01, 5.27478557e-01], [-6.96765462e-02, 2.99484737e-01], [ 1.95489528e-02, 9.94207504e-01], [-1.26266935e-01, 3.43883486e-01], [-8.39199835e-01, 7.00314293e-01], [ 8.46496761e-01, -1.16350315e-02], [ 2.29736915e-01, 7.69345906e-02], [ 1.25462994e+00, -3.92336633e-01], [ 1.96824632e+00, 5.49696837e-01], [ 1.02016913e+00, 3.53917835e-01], [-5.31164727e-01, 6.92582395e-01], [-1.10496675e+00, 4.86301545e-01], [ 2.54223306e-01, -1.89474183e-01], [ 1.25509219e+00, -3.32432682e-01], [ 1.57148223e+00, -8.26665774e-02], [ 1.11207737e+00, -3.29369368e-01], [-1.05050177e+00, 1.81123803e-01], [ 2.03731913e+00, 5.06268525e-01], [ 8.37104705e-01, 4.38321722e-01], [ 2.70944977e-02, 4.29302511e-02], [-2.68526403e-02, 5.04722011e-01], [ 4.47263472e-01, 9.95926194e-01], [-4.78109868e-01, 9.94572554e-01], [-4.04629660e-01, 8.86029986e-01], [ 1.03118702e+00, 1.02021446e-01], [ 2.67391648e-01, 1.15977827e-01], [ 1.31563032e+00, -4.68055845e-01], [ 2.72709481e-01, 1.06929344e+00], [-2.42902398e-02, 1.56393942e-01], [ 4.77795869e-01, -3.90843212e-01], [ 1.68292595e+00, -4.60914204e-01], [ 8.33302236e-01, 3.98915558e-01], [ 3.27488923e-01, 8.78131092e-01], [ 6.74827352e-01, -3.35549427e-01], [-2.45527648e-01, 9.38992348e-01], [ 1.28155163e+00, -4.76190226e-01], [ 3.01408950e-01, -8.43203922e-02], [ 5.80344402e-01, 1.01606800e+00], [ 6.54257532e-01, 7.78195644e-01], [ 2.18332379e-01, 2.39832599e-02], [ 1.17599534e+00, -5.14188256e-01], [ 6.53498000e-01, 8.42628380e-01], [ 1.46494235e+00, -2.88413814e-01], [ 4.50729150e-01, -4.81782194e-01], [ 1.59955423e+00, -1.97918553e-02], [ 4.23800432e-01, 8.61219660e-01], [ 6.47861265e-01, -3.73026171e-01], [ 1.54832195e+00, -3.82636573e-01], [ 
9.59507036e-01, -6.53322764e-01], [ 6.22396337e-01, 8.80914616e-01], [-1.14279515e+00, 5.33863108e-02], [-9.68196502e-01, 5.38851311e-01], [-6.91219117e-01, 4.43566916e-01], [ 1.82600270e-01, 1.76667308e-02], [ 9.34822203e-01, -3.97721167e-01], [ 1.70786789e+00, 5.01289734e-01], [-6.81686209e-01, 6.25964062e-01], [-7.06439154e-01, 5.56250067e-01], [ 5.01454747e-01, 7.38339946e-01], [ 1.47648298e+00, -3.69367080e-01], [-9.33946572e-01, 4.76649526e-01], [ 1.91815067e+00, 2.06055971e-01], [ 1.87199356e+00, -3.36720051e-01], [ 1.94311158e+00, 3.13223536e-01], [ 3.42365360e-01, 1.01614096e+00], [-7.12211574e-02, 8.20286056e-01], [ 8.38200245e-01, -5.80104512e-01], [-8.34317448e-01, 3.91907169e-01], [-6.52120149e-01, 7.85982383e-01], [-9.31767725e-01, 4.03435241e-01], [ 4.33542139e-01, 9.55743482e-01], [ 8.31174752e-01, 4.53839488e-01], [ 9.20525579e-01, 2.60496062e-01], [-3.67116967e-02, 3.42448285e-01], [-5.45442229e-01, 8.70811316e-01], [ 1.79020519e+00, 2.09022313e-01], [ 2.09486633e-01, -7.84617937e-04], [ 1.69222125e-01, 1.01894845e+00], [ 1.58819062e+00, -1.90820935e-01]]), array([0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1]))
# Unpack the dataset into 2-D point coordinates and their class labels.
X, y = moons

# Unlabelled scatter of the points; X.T yields the two coordinate columns.
plt.figure(figsize=(12, 8))
plt.scatter(*X.T)
plt.show()
# Unpack the dataset again so this cell can be run on its own.
X, y = moons

# Same scatter as before, but with a distinct marker per class label.
plt.figure(figsize=(12, 8))
mglearn.discrete_scatter(X[:, 0], X[:, 1], y)
plt.show()
2次元空間を10次元特徴空間に写像して分類する。 (カーネル法は使用しない)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import PolynomialFeatures
from sklearn.svm import LinearSVC
# Hold out a test split, stratified on the labels so both classes keep their
# proportions in each split.
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=0)

# Learn the standardisation from the training split only and reuse it for all
# other data. Refitting on the test split would scale it with different
# statistics than the ones the model was trained under.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

lin_svm = LinearSVC().fit(X_train_scaled, y_train)

# The classifier lives in the standardised feature space, so the decision
# boundary and the points must be drawn in those same coordinates.
X_scaled = scaler.transform(X)
plt.figure(figsize=(12, 8))
mglearn.plots.plot_2d_separator(lin_svm, X_scaled)
mglearn.discrete_scatter(X_scaled[:, 0], X_scaled[:, 1], y)
# Closing "$" added so the labels render as mathtext subscripts.
plt.xlabel("$x_0$", fontsize=20)
plt.ylabel("$x_1$", fontsize=20)
# Sanity check: confirm that the two classes are not linearly separable.
Text(0, 0.5, '$x_1s')
# Expand the two raw features into all polynomial terms up to degree 3.
poly = PolynomialFeatures(degree=3)
X_train_poly = poly.fit_transform(X_train)
# The transformer is fitted on the training split only; the test split is
# merely transformed with it.
X_test_poly = poly.transform(X_test)
print(X_train_poly.shape)  # shows that the feature space is now 10-dimensional
print(X_train_poly)
(150, 10) [[ 1.00000000e+00 8.86813848e-01 -4.84896235e-01 ... -3.81341214e-01 2.08511538e-01 -1.14010917e-01] [ 1.00000000e+00 9.24799914e-01 4.94247065e-01 ... 4.22707215e-01 2.25910272e-01 1.20734753e-01] [ 1.00000000e+00 -7.63797985e-01 8.26528213e-01 ... 4.82186114e-01 -5.21787743e-01 5.64641828e-01] ... [ 1.00000000e+00 -9.68196502e-01 5.38851311e-01 ... 5.05121626e-01 -2.81126249e-01 1.56461263e-01] [ 1.00000000e+00 -6.70964946e-01 7.93803059e-01 ... 3.57365342e-01 -4.22790644e-01 5.00193800e-01] [ 1.00000000e+00 -2.68526403e-02 5.04722011e-01 ... 3.63937020e-04 -6.84055729e-03 1.28575060e-01]]
# Show how the 10 dimensions are composed (one name per output column).
# `get_feature_names` was removed in scikit-learn 1.2 in favour of
# `get_feature_names_out`; use whichever this installation provides.
(
    poly.get_feature_names_out().tolist()
    if hasattr(poly, "get_feature_names_out")
    else poly.get_feature_names()
)
['1', 'x0', 'x1', 'x0^2', 'x0 x1', 'x1^2', 'x0^3', 'x0^2 x1', 'x0 x1^2', 'x1^3']
# Standardise the polynomial features with statistics learned from the training
# split only, then apply that same transform to the test split. Refitting on
# the test split would leak its statistics and mismatch the training scaling.
X_train_scaled_poly = scaler.fit_transform(X_train_poly)
X_test_scaled_poly = scaler.transform(X_test_poly)
lin_svm = LinearSVC().fit(X_train_scaled_poly, y_train)
# Check the hit rate: element-wise comparison of predictions with the true
# labels on the held-out split (True = correct prediction).
lin_svm.predict(X_test_scaled_poly) == y_test
array([ True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True])