by Alejandro Correa Bahnsen and Jesus Solano
version 1.4, January 2019
This notebook is licensed under a Creative Commons Attribution-ShareAlike 3.0 Unported License. Special thanks go to Rick Muller, Sandia National Laboratories.
Why are we learning about decision trees?
Students will be able to:
Major League Baseball player data from 1986-87:
Group exercise:
Rules for segmenting:
Above are the regions created by a computer:
Note: Years and Hits are both integers, but the convention is to use the midpoint between adjacent values to label a split.
These regions are used to make predictions on out-of-sample data. Thus, there are only three possible predictions! (Is this different from how linear regression makes predictions?)
Below is the equivalent regression tree:
The first split is Years < 4.5, so that split goes at the top of the tree. When a splitting rule is True, you follow the left branch; when it is False, you follow the right branch.
For players in the left branch, the mean Salary is $166,000, thus you label it with that value. (Salary has been divided by 1000 and log-transformed to 5.11.)
For players in the right branch, there is a further split on Hits < 117.5, dividing players into two more Salary regions: $403,000 (transformed to 6.00), and $846,000 (transformed to 6.74).
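To make the traversal concrete, here is that tree written as a plain Python function (a sketch; the dollar values are the back-transformed leaf means quoted above):

# a sketch of the Salary tree: True follows the left branch, False the right
def predict_salary(years, hits):
    if years < 4.5:           # first split
        return 166000         # left leaf (log-transformed: 5.11)
    if hits < 117.5:          # second split, right branch only
        return 403000         # transformed: 6.00
    return 846000             # transformed: 6.74

predict_salary(3, 100)   # 166000: fewer than 4.5 years
predict_salary(10, 150)  # 846000: veteran with many hits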
What does this tree tell you about your data?
Question: What do you like and dislike about decision trees so far?
Your training data is a tiny dataset of used vehicle sale prices. Your goal is to predict price for testing data.
(This is similar to a group_by operation.)
Ideal approach: Consider every possible partition of the feature space (computationally infeasible)
"Good enough" approach: recursive binary splitting
# vehicle data
import pandas as pd
url = 'https://github.com/albahnsen/PracticalMachineLearningClass/raw/master/datasets/vehicles_train.csv'
train = pd.read_csv(url)
# before splitting anything, just predict the mean of the entire dataset
train['prediction'] = train.price.mean()
train
| | price | year | miles | doors | vtype | prediction |
---|---|---|---|---|---|---|
0 | 22000 | 2012 | 13000 | 2 | car | 6571.428571 |
1 | 14000 | 2010 | 30000 | 2 | car | 6571.428571 |
2 | 13000 | 2010 | 73500 | 4 | car | 6571.428571 |
3 | 9500 | 2009 | 78000 | 4 | car | 6571.428571 |
4 | 9000 | 2007 | 47000 | 4 | car | 6571.428571 |
5 | 4000 | 2006 | 124000 | 2 | car | 6571.428571 |
6 | 3000 | 2004 | 177000 | 4 | car | 6571.428571 |
7 | 2000 | 2004 | 209000 | 4 | truck | 6571.428571 |
8 | 3000 | 2003 | 138000 | 2 | car | 6571.428571 |
9 | 1900 | 2003 | 160000 | 4 | car | 6571.428571 |
10 | 2500 | 2003 | 190000 | 2 | truck | 6571.428571 |
11 | 5000 | 2001 | 62000 | 4 | car | 6571.428571 |
12 | 1800 | 1999 | 163000 | 2 | truck | 6571.428571 |
13 | 1300 | 1997 | 138000 | 4 | car | 6571.428571 |
# first split: year < 2010 vs. year >= 2010, predicting the mean price on each side
year = 2010
train['pred'] = train.loc[train.year < year, 'price'].mean()
train.loc[train.year >= year, 'pred'] = train.loc[train.year >= year, 'price'].mean()
# RMSE of this two-region prediction
(((train['price'] - train['pred'])**2).mean()) ** 0.5
3042.7402778200435
# left node ("izq" = left): the subset with year < 2010
train_izq = train.loc[train.year < 2010].copy()
train_izq.year.unique()
array([2009, 2007, 2006, 2004, 2003, 2001, 1999, 1997])
# RMSE after splitting the given DataFrame on year < year ("año" = year)
def error_año(train, year):
    train['pred'] = train.loc[train.year < year, 'price'].mean()
    train.loc[train.year >= year, 'pred'] = train.loc[train.year >= year, 'price'].mean()
    print((((train['price'] - train['pred'])**2).mean()) ** 0.5)

# RMSE after splitting the given DataFrame on miles < miles
def error_miles(train, miles):
    train['pred'] = train.loc[train.miles < miles, 'price'].mean()
    train.loc[train.miles >= miles, 'pred'] = train.loc[train.miles >= miles, 'price'].mean()
    print((((train['price'] - train['pred'])**2).mean()) ** 0.5)
for year in train_izq.year.unique():
    print('Year ', year)
    error_año(train_izq, year)
Year  2009
2057.469761182851
Year  2007
1009.9754972525349
Year  2006
1588.559953943422
Year  2004
2291.254783922662
Year  2003
2609.750075111687
Year  2001
2474.322680507279
Year  1999
2584.235424119236
Year  1997
2712.7492078079777
train_izq.miles.describe()
count        11.000000
mean     135090.909091
std       53042.350147
min       47000.000000
25%      101000.000000
50%      138000.000000
75%      170000.000000
max      209000.000000
Name: miles, dtype: float64
for miles in [50000, 90000, 95000, 100000, 105000, 110000, 125000, 140000, 160000, 180000]:
    print('Miles ', miles)
    error_miles(train_izq, miles)
Miles  50000
2183.408511312697
Miles  90000
1258.6217811077274
Miles  95000
1258.6217811077274
Miles  100000
1258.6217811077274
Miles  105000
1258.6217811077274
Miles  110000
1258.6217811077274
Miles  125000
1527.2099167665622
Miles  140000
2244.42744267988
Miles  160000
2244.42744267988
Miles  180000
2597.5610160924484
# right node ("der" = right): the subset with year >= 2010
train_der = train.loc[train.year >= 2010].copy()

for year in train_der.year.unique():
    print('Year ', year)
    error_año(train_der, year)

print('----------------------------------')

for miles in [50000, 90000, 95000, 100000, 105000, 110000, 125000, 140000, 160000, 180000]:
    print('Miles ', miles)
    error_miles(train_der, miles)
Year  2012
408.248290463863
Year  2010
4027.681991198191
----------------------------------
Miles  50000
3265.986323710904
Miles  90000
4027.681991198191
Miles  95000
4027.681991198191
Miles  100000
4027.681991198191
Miles  105000
4027.681991198191
Miles  110000
4027.681991198191
Miles  125000
4027.681991198191
Miles  140000
4027.681991198191
Miles  160000
4027.681991198191
Miles  180000
4027.681991198191
# right-left node: within year >= 2010, the subset with year < 2012
train_der_izq = train_der.loc[train_der.year < 2012].copy()

for year in train_der_izq.year.unique():
    print('Year ', year)
    error_año(train_der_izq, year)

print('----------------------------------')

for miles in [25000, 50000, 90000, 95000, 100000, 105000, 110000, 125000, 140000, 160000, 180000]:
    print('Miles ', miles)
    error_miles(train_der_izq, miles)
Year  2010
500.0
----------------------------------
Miles  25000
500.0
Miles  50000
0.0
Miles  90000
500.0
Miles  95000
500.0
Miles  100000
500.0
Miles  105000
500.0
Miles  110000
500.0
Miles  125000
500.0
Miles  140000
500.0
Miles  160000
500.0
Miles  180000
500.0
train_der_izq
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
1 | 14000 | 2010 | 30000 | 2 | car | 6571.428571 | 13500.0 |
2 | 13000 | 2010 | 73500 | 4 | car | 6571.428571 | 13500.0 |
# left-left node: within year < 2010, the subset with year < 2007
train_izq_izq = train_izq.loc[train_izq.year < 2007].copy()
train_izq_izq
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
5 | 4000 | 2006 | 124000 | 2 | car | 6571.428571 | 4277.777778 |
6 | 3000 | 2004 | 177000 | 4 | car | 6571.428571 | 4277.777778 |
7 | 2000 | 2004 | 209000 | 4 | truck | 6571.428571 | 2250.000000 |
8 | 3000 | 2003 | 138000 | 2 | car | 6571.428571 | 4277.777778 |
9 | 1900 | 2003 | 160000 | 4 | car | 6571.428571 | 4277.777778 |
10 | 2500 | 2003 | 190000 | 2 | truck | 6571.428571 | 2250.000000 |
11 | 5000 | 2001 | 62000 | 4 | car | 6571.428571 | 4277.777778 |
12 | 1800 | 1999 | 163000 | 2 | truck | 6571.428571 | 4277.777778 |
13 | 1300 | 1997 | 138000 | 4 | car | 6571.428571 | 4277.777778 |
for year in train_izq_izq.year.unique():
    print('Year ', year)
    error_año(train_izq_izq, year)

print('----------------------------------')

for miles in [25000, 50000, 90000, 95000, 100000, 105000, 110000, 125000, 140000, 160000, 180000]:
    print('Miles ', miles)
    error_miles(train_izq_izq, miles)
Year  2006
1014.2731387550397
Year  2004
1092.8216960050065
Year  2003
1110.2218663819374
Year  2001
916.6450213894664
Year  1999
989.9494936611666
Year  1997
1110.3330609203888
----------------------------------
Miles  25000
1110.3330609203888
Miles  50000
1110.3330609203888
Miles  90000
764.3988196979084
Miles  95000
764.3988196979084
Miles  100000
764.3988196979084
Miles  105000
764.3988196979084
Miles  110000
764.3988196979084
Miles  125000
574.3180911666199
Miles  140000
970.6527013647398
Miles  160000
970.6527013647398
Miles  180000
1081.2617556017526
# left-right node: within year < 2010, the subset with year >= 2007
train_izq_der = train_izq.loc[train_izq.year >= 2007].copy()
train_izq_der
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
3 | 9500 | 2009 | 78000 | 4 | car | 6571.428571 | 4277.777778 |
4 | 9000 | 2007 | 47000 | 4 | car | 6571.428571 | 4277.777778 |
# right-right node: within year >= 2010, the subset with year >= 2012
train_der_der = train_der.loc[train_der.year >= 2012].copy()
train_der_der
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
0 | 22000 | 2012 | 13000 | 2 | car | 6571.428571 | 16333.333333 |
# split the left-left node on miles < 125000, the best cutpoint found above
train_izq_izq_izq = train_izq_izq.loc[train_izq_izq.miles < 125000].copy()
train_izq_izq_izq
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
5 | 4000 | 2006 | 124000 | 2 | car | 6571.428571 | 2857.142857 |
11 | 5000 | 2001 | 62000 | 4 | car | 6571.428571 | 2857.142857 |
train_izq_izq_der = train_izq_izq.loc[train_izq_izq.miles>=125000].copy()
train_izq_izq_der
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
6 | 3000 | 2004 | 177000 | 4 | car | 6571.428571 | 2857.142857 |
7 | 2000 | 2004 | 209000 | 4 | truck | 6571.428571 | 2250.000000 |
8 | 3000 | 2003 | 138000 | 2 | car | 6571.428571 | 2857.142857 |
9 | 1900 | 2003 | 160000 | 4 | car | 6571.428571 | 2857.142857 |
10 | 2500 | 2003 | 190000 | 2 | truck | 6571.428571 | 2250.000000 |
12 | 1800 | 1999 | 163000 | 2 | truck | 6571.428571 | 2857.142857 |
13 | 1300 | 1997 | 138000 | 4 | car | 6571.428571 | 2857.142857 |
for year in train_izq_izq_der.year.unique():
    print('Year ', year)
    error_año(train_izq_izq_der, year)

print('----------------------------------')

for miles in [140000, 160000, 180000]:
    print('Miles ', miles)
    error_miles(train_izq_izq_der, miles)
Year  2004
565.685424949238
Year  2003
419.69376590897457
Year  1999
461.8802153517006
Year  1997
593.8459911664722
----------------------------------
Miles  140000
592.452529743945
Miles  160000
592.452529743945
Miles  180000
593.416259587532
train_izq_izq_der_izq = train_izq_izq_der.loc[train_izq_izq_der.year<2003].copy()
train_izq_izq_der_izq
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
12 | 1800 | 1999 | 163000 | 2 | truck | 6571.428571 | 2200.0 |
13 | 1300 | 1997 | 138000 | 4 | car | 6571.428571 | 2200.0 |
train_izq_izq_der_der = train_izq_izq_der.loc[train_izq_izq_der.year>=2003].copy()
train_izq_izq_der_der
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
6 | 3000 | 2004 | 177000 | 4 | car | 6571.428571 | 2200.0 |
7 | 2000 | 2004 | 209000 | 4 | truck | 6571.428571 | 2250.0 |
8 | 3000 | 2003 | 138000 | 2 | car | 6571.428571 | 2200.0 |
9 | 1900 | 2003 | 160000 | 4 | car | 6571.428571 | 2200.0 |
10 | 2500 | 2003 | 190000 | 2 | truck | 6571.428571 | 2250.0 |
for year in train_izq_izq_der_der.year.unique():
    print('Year ', year)
    error_año(train_izq_izq_der_der, year)

print('----------------------------------')

for miles in [140000, 160000, 180000]:
    print('Miles ', miles)
    error_miles(train_izq_izq_der_der, miles)
Year  2004
470.46076705006266
Year  2003
470.7440918375928
----------------------------------
Miles  140000
392.42833740697165
Miles  160000
392.42833740697165
Miles  180000
431.66344915145794
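Putting the whole exercise together, the hand-built tree can be written as a single function (a sketch; each leaf returns the mean price of its node, taken from the tables above):

# a sketch of the tree grown by hand above; leaf values are node means
def predict_price(year, miles):
    if year >= 2010:
        if year >= 2012:
            return 22000.0   # train_der_der
        return 13500.0       # train_der_izq (splitting on miles < 50000 drives the error to 0)
    if year >= 2007:
        return 9250.0        # train_izq_der: mean of 9500 and 9000
    if miles < 125000:
        return 4500.0        # train_izq_izq_izq: mean of 4000 and 5000
    if year < 2003:
        return 1550.0        # train_izq_izq_der_izq: mean of 1800 and 1300
    return 2480.0            # train_izq_izq_der_der: mean of rows 6 through 10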
# calculate the RMSE for the mean-only baseline stored in train['prediction']
from sklearn import metrics
import numpy as np
np.sqrt(metrics.mean_squared_error(train.price, train.prediction))
5936.981985995983
# define a function that calculates the RMSE for a given split of miles
def mileage_split(miles):
    lower_mileage_price = train[train.miles < miles].price.mean()
    higher_mileage_price = train[train.miles >= miles].price.mean()
    train['prediction'] = np.where(train.miles < miles, lower_mileage_price, higher_mileage_price)
    return np.sqrt(metrics.mean_squared_error(train.price, train.prediction))
# calculate RMSE for tree which splits on miles < 50000
print('RMSE:', mileage_split(50000))
train
RMSE: 3984.0917425414564
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
0 | 22000 | 2012 | 13000 | 2 | car | 15000.000000 | 16333.333333 |
1 | 14000 | 2010 | 30000 | 2 | car | 15000.000000 | 16333.333333 |
2 | 13000 | 2010 | 73500 | 4 | car | 4272.727273 | 16333.333333 |
3 | 9500 | 2009 | 78000 | 4 | car | 4272.727273 | 3909.090909 |
4 | 9000 | 2007 | 47000 | 4 | car | 15000.000000 | 3909.090909 |
5 | 4000 | 2006 | 124000 | 2 | car | 4272.727273 | 3909.090909 |
6 | 3000 | 2004 | 177000 | 4 | car | 4272.727273 | 3909.090909 |
7 | 2000 | 2004 | 209000 | 4 | truck | 4272.727273 | 3909.090909 |
8 | 3000 | 2003 | 138000 | 2 | car | 4272.727273 | 3909.090909 |
9 | 1900 | 2003 | 160000 | 4 | car | 4272.727273 | 3909.090909 |
10 | 2500 | 2003 | 190000 | 2 | truck | 4272.727273 | 3909.090909 |
11 | 5000 | 2001 | 62000 | 4 | car | 4272.727273 | 3909.090909 |
12 | 1800 | 1999 | 163000 | 2 | truck | 4272.727273 | 3909.090909 |
13 | 1300 | 1997 | 138000 | 4 | car | 4272.727273 | 3909.090909 |
# calculate RMSE for tree which splits on miles < 100000
print('RMSE:', mileage_split(100000))
train
RMSE: 3530.146530076269
| | price | year | miles | doors | vtype | prediction | pred |
---|---|---|---|---|---|---|---|
0 | 22000 | 2012 | 13000 | 2 | car | 12083.333333 | 16333.333333 |
1 | 14000 | 2010 | 30000 | 2 | car | 12083.333333 | 16333.333333 |
2 | 13000 | 2010 | 73500 | 4 | car | 12083.333333 | 16333.333333 |
3 | 9500 | 2009 | 78000 | 4 | car | 12083.333333 | 3909.090909 |
4 | 9000 | 2007 | 47000 | 4 | car | 12083.333333 | 3909.090909 |
5 | 4000 | 2006 | 124000 | 2 | car | 2437.500000 | 3909.090909 |
6 | 3000 | 2004 | 177000 | 4 | car | 2437.500000 | 3909.090909 |
7 | 2000 | 2004 | 209000 | 4 | truck | 2437.500000 | 3909.090909 |
8 | 3000 | 2003 | 138000 | 2 | car | 2437.500000 | 3909.090909 |
9 | 1900 | 2003 | 160000 | 4 | car | 2437.500000 | 3909.090909 |
10 | 2500 | 2003 | 190000 | 2 | truck | 2437.500000 | 3909.090909 |
11 | 5000 | 2001 | 62000 | 4 | car | 12083.333333 | 3909.090909 |
12 | 1800 | 1999 | 163000 | 2 | truck | 2437.500000 | 3909.090909 |
13 | 1300 | 1997 | 138000 | 4 | car | 2437.500000 | 3909.090909 |
# check all possible mileage splits
mileage_range = range(train.miles.min(), train.miles.max(), 1000)
RMSE = [mileage_split(miles) for miles in mileage_range]
# allow plots to appear in the notebook
%matplotlib inline
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (6, 4)
plt.rcParams['font.size'] = 14
# plot mileage cutpoint (x-axis) versus RMSE (y-axis)
plt.plot(mileage_range, RMSE)
plt.xlabel('Mileage cutpoint')
plt.ylabel('RMSE (lower is better)')
Text(0, 0.5, 'RMSE (lower is better)')
Recap: Before every split, this process is repeated for every feature, and the feature and cutpoint that produce the lowest MSE are chosen.
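In code, one round of that search might look like this (best_split is a hypothetical helper, not part of scikit-learn; it assumes numeric features and uses the midpoints between adjacent values as candidate cutpoints):

# sketch of one step of recursive binary splitting: try every feature and
# candidate cutpoint, and keep the pair with the lowest RMSE
def best_split(df, features, target='price'):
    best_feature, best_cutpoint, best_rmse = None, None, np.inf
    for feature in features:
        values = np.sort(df[feature].unique())
        cutpoints = (values[:-1] + values[1:]) / 2  # midpoints between adjacent values
        for cut in cutpoints:
            left_mean = df.loc[df[feature] < cut, target].mean()
            right_mean = df.loc[df[feature] >= cut, target].mean()
            pred = np.where(df[feature] < cut, left_mean, right_mean)
            rmse = np.sqrt(((df[target] - pred) ** 2).mean())
            if rmse < best_rmse:
                best_feature, best_cutpoint, best_rmse = feature, cut, rmse
    return best_feature, best_cutpoint, best_rmse

best_split(train, ['year', 'miles'])  # expected to choose year with a cutpoint of 2009.5, as found by hand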
# encode car as 0 and truck as 1
train['vtype'] = train.vtype.map({'car':0, 'truck':1})
# define X and y
feature_cols = ['year', 'miles', 'doors', 'vtype']
X = train[feature_cols]
y = train.price
# instantiate a DecisionTreeRegressor (with random_state=1)
from sklearn.tree import DecisionTreeRegressor
treereg = DecisionTreeRegressor(random_state=1)
treereg
DecisionTreeRegressor(criterion='mse', max_depth=None, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=1, splitter='best')
# use leave-one-out cross-validation (LOOCV) to estimate the RMSE for this model
import numpy as np
from sklearn.model_selection import cross_val_score
scores = cross_val_score(treereg, X, y, cv=14, scoring='neg_mean_squared_error')
np.mean(np.sqrt(-scores))
3107.1428571428573
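With 14 rows, cv=14 is leave-one-out. The same estimate can be written more explicitly with scikit-learn's LeaveOneOut splitter:

# equivalent LOOCV, spelled out with the LeaveOneOut splitter
from sklearn.model_selection import LeaveOneOut
scores = cross_val_score(treereg, X, y, cv=LeaveOneOut(), scoring='neg_mean_squared_error')
np.mean(np.sqrt(-scores))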
The training error continues to go down as the tree size increases (due to overfitting), but the lowest cross-validation error occurs for a tree with 3 leaves.
Let's try to reduce the RMSE by tuning the max_depth parameter:
# try different values one-by-one
treereg = DecisionTreeRegressor(max_depth=1, random_state=1)
scores = cross_val_score(treereg, X, y, cv=14, scoring='neg_mean_squared_error')
np.mean(np.sqrt(-scores))
4050.1443001443
Or, we could write a loop to try a range of values:
# list of values to try
max_depth_range = range(1, 8)
# list to store the average RMSE for each value of max_depth
RMSE_scores = []
# use LOOCV with each value of max_depth
for depth in max_depth_range:
    treereg = DecisionTreeRegressor(max_depth=depth, random_state=1)
    MSE_scores = cross_val_score(treereg, X, y, cv=14, scoring='neg_mean_squared_error')
    RMSE_scores.append(np.mean(np.sqrt(-MSE_scores)))
%matplotlib inline
import matplotlib.pyplot as plt
# plot max_depth (x-axis) versus RMSE (y-axis)
plt.plot(max_depth_range, RMSE_scores)
plt.xlabel('max_depth')
plt.ylabel('RMSE (lower is better)')
Text(0, 0.5, 'RMSE (lower is better)')
# max_depth=3 was best, so fit a tree using that parameter
treereg = DecisionTreeRegressor(max_depth=3, random_state=1)
treereg.fit(X, y)
DecisionTreeRegressor(criterion='mse', max_depth=3, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=1, splitter='best')
# "Gini importance" of each feature: the (normalized) total reduction of error brought by that feature
pd.DataFrame({'feature':feature_cols, 'importance':treereg.feature_importances_})
| | feature | importance |
---|---|---|
0 | year | 0.798744 |
1 | miles | 0.201256 |
2 | doors | 0.000000 |
3 | vtype | 0.000000 |
# create a Graphviz file
from sklearn.tree import export_graphviz
export_graphviz(treereg, out_file='tree_vehicles.dot', feature_names=feature_cols)
# At the command line, run this to convert to PNG:
# dot -Tpng tree_vehicles.dot -o tree_vehicles.png
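If Graphviz is not available, newer versions of scikit-learn (0.21+) can also print the fitted tree as plain text (a convenience, not used in the original notebook):

# text rendering of the fitted tree (requires scikit-learn >= 0.21)
from sklearn.tree import export_text
print(export_text(treereg, feature_names=feature_cols))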
Reading the internal nodes:
Reading the leaves:
# read the testing data
url = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/vehicles_test.csv'
test = pd.read_csv(url)
test['vtype'] = test.vtype.map({'car':0, 'truck':1})
test
| | price | year | miles | doors | vtype |
---|---|---|---|---|---|
0 | 3000 | 2003 | 130000 | 4 | 1 |
1 | 6000 | 2005 | 82500 | 4 | 0 |
2 | 12000 | 2010 | 60000 | 2 | 0 |
Question: Using the tree diagram above, what predictions will the model make for each observation?
# use fitted model to make predictions on testing data
X_test = test[feature_cols]
y_test = test.price
y_pred = treereg.predict(X_test)
y_pred
array([ 4000., 5000., 13500.])
# calculate RMSE
from sklearn.metrics import mean_squared_error
np.sqrt(mean_squared_error(y_test, y_pred))
1190.2380714238084
Example: Predict whether Barack Obama or Hillary Clinton will win the Democratic primary in a particular county in 2008:
Questions:
regression trees | classification trees |
---|---|
predict a continuous response | predict a categorical response |
predict using mean response of each leaf | predict using most commonly occurring class of each leaf |
splits are chosen to minimize MSE | splits are chosen to minimize Gini index (discussed below) |
Common options for the splitting criteria:
Pretend we are predicting whether someone buys an iPhone or an Android:
Our goal in making splits is to reduce the classification error rate. Let's try splitting on gender:
Compare that with a split on age:
The decision tree algorithm will try every possible split across all features, and choose the split that reduces the error rate the most.
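Using the class counts from the Gini example below (10 iPhone and 15 Android users overall), here is a quick sketch of that error-rate comparison:

# classification error rate: predict the majority class in each node and
# count the misclassified minority (each node is an (iPhone, Android) pair)
def error_rate(*nodes):
    errors = sum(min(iphone, android) for iphone, android in nodes)
    total = sum(iphone + android for iphone, android in nodes)
    return errors / total

error_rate((10, 15))           # no split: 0.40
error_rate((2, 12), (8, 3))    # split on gender: 0.20
error_rate((4, 8), (6, 7))     # split on age: 0.40, no improvement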
Calculate the Gini index before making a split:
$$1 - \left(\frac{iPhone}{Total}\right)^2 - \left(\frac{Android}{Total}\right)^2 = 1 - \left(\frac{10}{25}\right)^2 - \left(\frac{15}{25}\right)^2 = 0.48$$

Evaluating the split on gender using Gini index:

$$\text{Males: } 1 - \left(\frac{2}{14}\right)^2 - \left(\frac{12}{14}\right)^2 = 0.24$$

$$\text{Females: } 1 - \left(\frac{8}{11}\right)^2 - \left(\frac{3}{11}\right)^2 = 0.40$$

$$\text{Weighted Average: } 0.24 \left(\frac{14}{25}\right) + 0.40 \left(\frac{11}{25}\right) = 0.31$$

Evaluating the split on age using Gini index:

$$\text{30 or younger: } 1 - \left(\frac{4}{12}\right)^2 - \left(\frac{8}{12}\right)^2 = 0.44$$

$$\text{31 or older: } 1 - \left(\frac{6}{13}\right)^2 - \left(\frac{7}{13}\right)^2 = 0.50$$

$$\text{Weighted Average: } 0.44 \left(\frac{12}{25}\right) + 0.50 \left(\frac{13}{25}\right) = 0.47$$

Again, the decision tree algorithm will try every possible split, and will choose the split that reduces the Gini index (and thus increases the "node purity") the most.
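To verify the arithmetic, the same computation as a short sketch:

# Gini index of a single node, and the weighted average across child nodes
def gini(iphone, android):
    total = iphone + android
    return 1 - (iphone / total) ** 2 - (android / total) ** 2

def weighted_gini(*nodes):
    total = sum(i + a for i, a in nodes)
    return sum((i + a) / total * gini(i, a) for i, a in nodes)

round(gini(10, 15), 2)                     # 0.48, before splitting
round(weighted_gini((2, 12), (8, 3)), 2)   # 0.31, split on gender
round(weighted_gini((4, 8), (6, 7)), 2)    # 0.47, split on age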
Note: There is another common splitting criterion called cross-entropy. It's numerically similar to the Gini index but slower to compute, so it's not as popular.
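For reference, here is the cross-entropy of the same pre-split node (scikit-learn exposes this criterion as criterion='entropy' in DecisionTreeClassifier):

# cross-entropy (in bits) of the 10-iPhone / 15-Android node
p = np.array([10, 15]) / 25
-np.sum(p * np.log2(p))  # about 0.971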
We'll build a classification tree using the Titanic data:
# read in the data
url = 'https://raw.githubusercontent.com/justmarkham/DAT8/master/data/titanic.csv'
titanic = pd.read_csv(url)
# encode female as 0 and male as 1
titanic['Sex'] = titanic.Sex.map({'female':0, 'male':1})
# fill in the missing values for age with the median age
titanic.Age.fillna(titanic.Age.median(), inplace=True)
# create a DataFrame of dummy variables for Embarked
embarked_dummies = pd.get_dummies(titanic.Embarked, prefix='Embarked')
embarked_dummies.drop(embarked_dummies.columns[0], axis=1, inplace=True)
# concatenate the original DataFrame and the dummy DataFrame
titanic = pd.concat([titanic, embarked_dummies], axis=1)
# print the updated DataFrame
titanic.head()
| | PassengerId | Survived | Pclass | Name | Sex | Age | SibSp | Parch | Ticket | Fare | Cabin | Embarked | Embarked_Q | Embarked_S |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | 0 | 3 | Braund, Mr. Owen Harris | 1 | 22.0 | 1 | 0 | A/5 21171 | 7.2500 | NaN | S | 0 | 1 |
1 | 2 | 1 | 1 | Cumings, Mrs. John Bradley (Florence Briggs Th... | 0 | 38.0 | 1 | 0 | PC 17599 | 71.2833 | C85 | C | 0 | 0 |
2 | 3 | 1 | 3 | Heikkinen, Miss. Laina | 0 | 26.0 | 0 | 0 | STON/O2. 3101282 | 7.9250 | NaN | S | 0 | 1 |
3 | 4 | 1 | 1 | Futrelle, Mrs. Jacques Heath (Lily May Peel) | 0 | 35.0 | 1 | 0 | 113803 | 53.1000 | C123 | S | 0 | 1 |
4 | 5 | 0 | 3 | Allen, Mr. William Henry | 1 | 35.0 | 0 | 0 | 373450 | 8.0500 | NaN | S | 0 | 1 |
# define X and y
feature_cols = ['Pclass', 'Sex', 'Age', 'Embarked_Q', 'Embarked_S']
X = titanic[feature_cols]
y = titanic.Survived
# fit a classification tree with max_depth=3 on all data
from sklearn.tree import DecisionTreeClassifier
treeclf = DecisionTreeClassifier(max_depth=3, random_state=1)
treeclf.fit(X, y)
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3, max_features=None, max_leaf_nodes=None, min_impurity_decrease=0.0, min_impurity_split=None, min_samples_leaf=1, min_samples_split=2, min_weight_fraction_leaf=0.0, presort=False, random_state=1, splitter='best')
# create a Graphviz file
export_graphviz(treeclf, out_file='tree_titanic.dot', feature_names=feature_cols)
# At the command line, run this to convert to PNG:
# dot -Tpng tree_titanic.dot -o tree_titanic.png
Notice the split in the bottom right: the same class is predicted in both of its leaves. That split didn't affect the classification error rate, though it did increase the node purity, which is important because it increases the accuracy of our predicted probabilities.
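The purity gain shows up in the predicted probabilities (a quick check; predict_proba returns one column per class):

# predicted probabilities for the first five passengers
# (column 0: P(did not survive), column 1: P(survived))
treeclf.predict_proba(X[:5])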
# compute the feature importances
pd.DataFrame({'feature':feature_cols, 'importance':treeclf.feature_importances_})
| | feature | importance |
---|---|---|
0 | Pclass | 0.242664 |
1 | Sex | 0.655584 |
2 | Age | 0.064494 |
3 | Embarked_Q | 0.000000 |
4 | Embarked_S | 0.037258 |
Advantages of decision trees:
Disadvantages of decision trees: