Luokittelu - Gaussian Naive Bayes

Klassinen esimerkki luokittelusta on kurjenmiekkojen (iris) luokittelu kolmeen lajiin (setosa, versicolor, virginica) terä- (petal) ja verholehtien (sepal) koon mukaan. Seuraavassa käytän luokitteluun Gaussian Naive Bayes -menelmää.

Gaussian Naive Bayes -menetelmän idea

Menetelmä perustuu Bayesin teoreemaan. Jokaisen luokan kohdalla estimoidaan selittävien muuttujien jakaumat käyttäen normaalijakauma-oletusta. Tämän jälkeen voidaan laskea todennäköisyydet eri luokkiin kuulumiselle. Näiden todennäköisyyksien lukuarvot eivät sellaisenaan ole luotettavia. Olennaista on mallin toteuttama luokittelu.

Mallin oletuksena on, että selittävien muuttujien arvot ovat kussakin luokassa toisistaan riippumattomia. Käytännössä Gaussian Naive Bayes toimii hyvin monenlaisten datojen kohdalla vaikka riippumattomuusoletus ei toteutuisikaan.

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

#Vaikuttaa kaavioiden ulkoasuun:
sns.set()
In [2]:
#Esimerkkiaineisto löytyy seaborn-kirjastosta:
iris = sns.load_dataset('iris')
iris.head()
Out[2]:
sepal_length sepal_width petal_length petal_width species
0 5.1 3.5 1.4 0.2 setosa
1 4.9 3.0 1.4 0.2 setosa
2 4.7 3.2 1.3 0.2 setosa
3 4.6 3.1 1.5 0.2 setosa
4 5.0 3.6 1.4 0.2 setosa
In [3]:
#Seaborn-kirjaston pairplot havainnollistaa hyvin lajin (species) riippuvuutta petal- ja sepal-mitoista:
sns.pairplot(iris, hue='species')
C:\Users\Aki\Anaconda3\lib\site-packages\scipy\stats\stats.py:1713: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  return np.add.reduce(sorted[indexer] * weights, axis=axis) / sumval
Out[3]:
<seaborn.axisgrid.PairGrid at 0x284ae8f06d8>
In [4]:
#Feature-matriisi on iris-data ilman species-muuttujaa:
X = iris.drop('species', axis=1)

#Target on species (laji):
y = iris['species']
In [5]:
#train_test_split jakaa datan opetusdataan ja testidataan (25 % datasta, jollei toisin määrätä).
#random_state määrittää satunnaislukugeneraattorin siemenluvun. Sama siemenluku takaa saman jaottelun
#eri suorituskerroilla.

from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=5)
In [6]:
#Gaussian naive bayes -mallin tuonti:
from sklearn.naive_bayes import GaussianNB

#Mallin sovitus:
malli = GaussianNB()
malli.fit(X_train, y_train)

#Mallin mukaisten ennusteiden laskeminen opetusdatalle ja testidatalle:
y_train_malli = malli.predict(X_train)
y_test_malli = malli.predict(X_test)
In [7]:
#Oikeaan osuneiden ennusteiden osuus opetusdatassa:

from sklearn.metrics import accuracy_score

accuracy_score(y_train, y_train_malli)
Out[7]:
0.9732142857142857
In [8]:
#Oikeaan osuneiden ennusteiden osuus testidatassa:
accuracy_score(y_test, y_test_malli)
Out[8]:
0.9210526315789473
In [9]:
#Confusion-matriisi opetusdatalle:

from sklearn.metrics import confusion_matrix

print(confusion_matrix(y_train, y_train_malli))
[[38  0  0]
 [ 0 35  1]
 [ 0  2 36]]

Kaikki Setosa-lajiin kuuluvat ennustetaan oikein, yksi Versicolor ennustetaan virheellisesti Virginica-lajiin kuuluvaksi, kaksi Virginica-lajiin kuuluvaa ennustetaan virheellisesti Versicolor-lajiin kuuluviksi.

In [10]:
#Confusion-matriisi testidatalle:

print(confusion_matrix(y_test, y_test_malli))
[[12  0  0]
 [ 0 13  1]
 [ 0  2 10]]

Kaikki Setosa-lajiin kuuluvat ennustetaan oikein, yksi Versicolor ennustetaan virheellisesti Virginica-lajiin kuuluvaksi, kaksi Virginica-lajiin kuuluvaa ennustetaan virheellisesti Versicolor-lajiin kuuluviksi.

In [11]:
#Uusi data, jota ei ole valmiiksi luokiteltu:
Xnew = pd.read_excel('http://taanila.fi/irisnew.xlsx')
Xnew
Out[11]:
sepal_length sepal_width petal_length petal_width
0 5.0 3.5 1.5 0.3
1 8.1 3.3 6.5 1.9
2 6.0 3.0 3.0 0.5
In [12]:
#Luokittelu:
malli.predict(Xnew)
Out[12]:
array(['setosa', 'virginica', 'versicolor'], dtype='<U10')