During the session we classified the Iris dataset with KNN. Now let's try to classify it with decision trees!
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
warnings.simplefilter("ignore", FutureWarning)
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
import sklearn.datasets
from yellowbrick.classifier import class_prediction_error
from yellowbrick.features import feature_importances
from yellowbrick.model_selection import learning_curve
red, blue, green = sns.color_palette('Set1', 3)
sns.set(
    style='ticks',
    context='talk',
    palette='Set1'
)
If you have trouble importing Yellowbrick, create a new cell and run the following command in it:
!python3 -m pip install yellowbrick
We start by loading the data.
X, y = sklearn.datasets.load_iris(return_X_y=True)
feature_names = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
df = pd.DataFrame(data=X, columns=feature_names)
df['target'] = y
df.head()
|   | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
|---|---|---|---|---|---|
| 0 | 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 1 | 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 2 | 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 3 | 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 4 | 5.0 | 3.6 | 1.4 | 0.2 | 0 |
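The target column is integer-encoded; if you prefer the species names, the dataset also exposes them. A small optional sketch that adds a species column:
# Optional: map the integer targets to the species names for readability.
target_names = sklearn.datasets.load_iris().target_names
df['species'] = df['target'].map(dict(enumerate(target_names)))
df['species'].value_counts()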
First, split the data into train and test sets.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=0)
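Iris is balanced, so a plain random split works well; if you want the class proportions preserved exactly in both sets, train_test_split also accepts stratify. An optional variant (the _s names are just placeholders, and the cells below keep using the split above):
# Optional variant: stratify=y keeps the class proportions identical in train and test.
X_train_s, X_test_s, y_train_s, y_test_s = train_test_split(
    X, y, test_size=0.33, random_state=0, stratify=y)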
Second, create, fit and score a decision tree (or random forest) classifier.
classifier = DecisionTreeClassifier()
classifier.fit(X_train, y_train)
print("Accuracy = ", classifier.score(X_test, y_test))
Accuracy = 0.96
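Since RandomForestClassifier is already imported, the random forest variant mentioned above is a one-line swap; a sketch with default hyperparameters (the forest name and random_state=0 are arbitrary choices here):
# Random forest variant: an ensemble of decision trees with the same fit/score API.
forest = RandomForestClassifier(random_state=0)
forest.fit(X_train, y_train)
print("Accuracy = ", forest.score(X_test, y_test))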
Plot the feature importances to determine which features matter most.
feature_importances(classifier, X, y, labels=feature_names);
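If Yellowbrick is unavailable, the fitted tree exposes the same information through its feature_importances_ attribute, which you can plot directly with matplotlib; a minimal fallback sketch using the classifier fitted above:
# Fallback: plot the impurity-based importances stored on the fitted tree.
importances = classifier.feature_importances_
order = np.argsort(importances)
plt.barh(np.array(feature_names)[order], importances[order])
plt.xlabel("relative importance");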
Plot a learning curve to find out how many samples are required to train the model.
learning_curve(DecisionTreeClassifier(), X, y,
               train_sizes=np.arange(0.1, 1.1, 0.1));
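The class_prediction_error visualizer imported at the top shows, for each true class, how the test predictions are distributed; a sketch assuming the usual Yellowbrick quick-method signature (model, X_train, y_train, X_test, y_test):
# Class prediction error: one stacked bar per true class, split by predicted class.
class_prediction_error(classifier, X_train, y_train, X_test, y_test);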
End.