We're trying to predict the high_income column.
import pandas as pd, numpy as np
import matplotlib.pyplot as plt, seaborn as sns
sns.set(style = "whitegrid", font_scale = 1.2)
%matplotlib inline
income = pd.read_csv("income.csv")
income.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 32561 entries, 0 to 32560 Data columns (total 15 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 age 32561 non-null int64 1 workclass 32561 non-null object 2 fnlwgt 32561 non-null int64 3 education 32561 non-null object 4 education_num 32561 non-null int64 5 marital_status 32561 non-null object 6 occupation 32561 non-null object 7 relationship 32561 non-null object 8 race 32561 non-null object 9 sex 32561 non-null object 10 capital_gain 32561 non-null int64 11 capital_loss 32561 non-null int64 12 hours_per_week 32561 non-null int64 13 native_country 32561 non-null object 14 high_income 32561 non-null object dtypes: int64(6), object(9) memory usage: 3.7+ MB
# Render floats with five decimals in every DataFrame displayed from here on,
# then summarise the numeric columns.
pd.set_option("display.float_format", "{:.5f}".format)
income.describe()
age | fnlwgt | education_num | capital_gain | capital_loss | hours_per_week | |
---|---|---|---|---|---|---|
count | 32561.00000 | 32561.00000 | 32561.00000 | 32561.00000 | 32561.00000 | 32561.00000 |
mean | 38.58165 | 189778.36651 | 10.08068 | 1077.64884 | 87.30383 | 40.43746 |
std | 13.64043 | 105549.97770 | 2.57272 | 7385.29208 | 402.96022 | 12.34743 |
min | 17.00000 | 12285.00000 | 1.00000 | 0.00000 | 0.00000 | 1.00000 |
25% | 28.00000 | 117827.00000 | 9.00000 | 0.00000 | 0.00000 | 40.00000 |
50% | 37.00000 | 178356.00000 | 10.00000 | 0.00000 | 0.00000 | 40.00000 |
75% | 48.00000 | 237051.00000 | 12.00000 | 0.00000 | 0.00000 | 45.00000 |
max | 90.00000 | 1484705.00000 | 16.00000 | 99999.00000 | 4356.00000 | 99.00000 |
# Summary (count / unique / top / freq) of the object-typed columns.
income.describe(include=["object"])
workclass | education | marital_status | occupation | relationship | race | sex | native_country | high_income | |
---|---|---|---|---|---|---|---|---|---|
count | 32561 | 32561 | 32561 | 32561 | 32561 | 32561 | 32561 | 32561 | 32561 |
unique | 9 | 16 | 7 | 15 | 6 | 5 | 2 | 42 | 2 |
top | Private | HS-grad | Married-civ-spouse | Prof-specialty | Husband | White | Male | United-States | <=50K |
freq | 22696 | 10501 | 14976 | 4140 | 13193 | 27816 | 21790 | 29170 | 24720 |
# Peek at the first five rows.
income.head(n=5)
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | high_income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 39 | State-gov | 77516 | Bachelors | 13 | Never-married | Adm-clerical | Not-in-family | White | Male | 2174 | 0 | 40 | United-States | <=50K |
1 | 50 | Self-emp-not-inc | 83311 | Bachelors | 13 | Married-civ-spouse | Exec-managerial | Husband | White | Male | 0 | 0 | 13 | United-States | <=50K |
2 | 38 | Private | 215646 | HS-grad | 9 | Divorced | Handlers-cleaners | Not-in-family | White | Male | 0 | 0 | 40 | United-States | <=50K |
3 | 53 | Private | 234721 | 11th | 7 | Married-civ-spouse | Handlers-cleaners | Husband | Black | Male | 0 | 0 | 40 | United-States | <=50K |
4 | 28 | Private | 338409 | Bachelors | 13 | Married-civ-spouse | Prof-specialty | Wife | Black | Female | 0 | 0 | 40 | Cuba | <=50K |
# Names of the int64 columns, together with how many there are.
numeric_cols = income.select_dtypes(include=[np.int64]).columns
numeric_cols, numeric_cols.size
(Index(['age', 'fnlwgt', 'education_num', 'capital_gain', 'capital_loss', 'hours_per_week'], dtype='object'), 6)
# One 10-bin histogram per numeric column, three plots per row. The grid is
# sized from numeric_cols instead of being hard-coded to 2 x 3 / range(6).
hist_ncols = 3
hist_nrows = max(1, -(-len(numeric_cols) // hist_ncols))  # ceiling division
fig, ax = plt.subplots(figsize = (15, 5 * hist_nrows), nrows = hist_nrows, ncols = hist_ncols,
                       squeeze = False)
for subplot_ax, col in zip(ax.ravel(), numeric_cols):
    sns.histplot(data = income[col], ax = subplot_ax, bins = 10)
    subplot_ax.set_title(col)
    subplot_ax.set_ylabel("")
    subplot_ax.set_xlabel("")
# Hide grid cells left over when the column count is not a multiple of 3.
for subplot_ax in ax.ravel()[len(numeric_cols):]:
    subplot_ax.set_visible(False)
plt.tight_layout()
plt.show()
# Names of the object-typed (categorical) columns, printed with their count.
categorical_cols = income.select_dtypes(include=["object"]).columns
print(categorical_cols, categorical_cols.size)
Index(['workclass', 'education', 'marital_status', 'occupation', 'relationship', 'race', 'sex', 'native_country', 'high_income'], dtype='object') 9
# Split the categorical columns by cardinality: fewer than 10 distinct values
# go to categorical_cols1 (small grid of plots), the rest to categorical_cols2.
n_distinct = {col: len(income[col].unique()) for col in categorical_cols}
categorical_cols1 = [col for col in categorical_cols if n_distinct[col] < 10]
categorical_cols2 = [col for col in categorical_cols if n_distinct[col] >= 10]
print(categorical_cols1, "\n")
print(categorical_cols2)
['workclass', 'marital_status', 'relationship', 'race', 'sex', 'high_income'] ['education', 'occupation', 'native_country']
# Count plots for the low-cardinality categorical columns, three per row.
# The grid is sized from categorical_cols1 instead of a hard-coded range(6).
count_ncols = 3
count_nrows = max(1, -(-len(categorical_cols1) // count_ncols))  # ceiling division
fig, ax = plt.subplots(figsize = (12, 4 * count_nrows), nrows = count_nrows, ncols = count_ncols,
                       squeeze = False)
for subplot_ax, col in zip(ax.ravel(), categorical_cols1):
    sns.countplot(x = income[col], ax = subplot_ax)
    subplot_ax.set_title(col)
    subplot_ax.set_ylabel("")
    subplot_ax.set_xlabel("")
    # Rotate the labels seaborn already placed, rather than reading them back
    # and re-setting them with set_xticklabels.
    subplot_ax.tick_params(axis = "x", labelrotation = 90)
# Hide grid cells left over when the column count is not a multiple of 3.
for subplot_ax in ax.ravel()[len(categorical_cols1):]:
    subplot_ax.set_visible(False)
plt.tight_layout()
plt.show()
# One full-width count plot per high-cardinality categorical column; the
# number of rows follows categorical_cols2 instead of a hard-coded 3.
n_wide_plots = max(1, len(categorical_cols2))
fig, ax = plt.subplots(figsize = (12, 5 * n_wide_plots), nrows = n_wide_plots, ncols = 1)
# With a single row, subplots returns a bare Axes; atleast_1d makes it iterable.
for subplot_ax, col in zip(np.atleast_1d(ax), categorical_cols2):
    sns.countplot(x = income[col], ax = subplot_ax)
    subplot_ax.set_title(col)
    subplot_ax.set_ylabel("")
    subplot_ax.set_xlabel("")
    # Rotate the labels seaborn already placed instead of re-setting them.
    subplot_ax.tick_params(axis = "x", labelrotation = 90)
plt.tight_layout()
plt.show()
# List every distinct value of each categorical column, separated by a rule.
for col in categorical_cols:
    distinct_values = income[col].unique()
    print(f"{col} | {distinct_values}")
    print("-------------------------------------------------------")
workclass | [' State-gov' ' Self-emp-not-inc' ' Private' ' Federal-gov' ' Local-gov' ' ?' ' Self-emp-inc' ' Without-pay' ' Never-worked'] ------------------------------------------------------- education | [' Bachelors' ' HS-grad' ' 11th' ' Masters' ' 9th' ' Some-college' ' Assoc-acdm' ' Assoc-voc' ' 7th-8th' ' Doctorate' ' Prof-school' ' 5th-6th' ' 10th' ' 1st-4th' ' Preschool' ' 12th'] ------------------------------------------------------- marital_status | [' Never-married' ' Married-civ-spouse' ' Divorced' ' Married-spouse-absent' ' Separated' ' Married-AF-spouse' ' Widowed'] ------------------------------------------------------- occupation | [' Adm-clerical' ' Exec-managerial' ' Handlers-cleaners' ' Prof-specialty' ' Other-service' ' Sales' ' Craft-repair' ' Transport-moving' ' Farming-fishing' ' Machine-op-inspct' ' Tech-support' ' ?' ' Protective-serv' ' Armed-Forces' ' Priv-house-serv'] ------------------------------------------------------- relationship | [' Not-in-family' ' Husband' ' Wife' ' Own-child' ' Unmarried' ' Other-relative'] ------------------------------------------------------- race | [' White' ' Black' ' Asian-Pac-Islander' ' Amer-Indian-Eskimo' ' Other'] ------------------------------------------------------- sex | [' Male' ' Female'] ------------------------------------------------------- native_country | [' United-States' ' Cuba' ' Jamaica' ' India' ' ?' 
' Mexico' ' South' ' Puerto-Rico' ' Honduras' ' England' ' Canada' ' Germany' ' Iran' ' Philippines' ' Italy' ' Poland' ' Columbia' ' Cambodia' ' Thailand' ' Ecuador' ' Laos' ' Taiwan' ' Haiti' ' Portugal' ' Dominican-Republic' ' El-Salvador' ' France' ' Guatemala' ' China' ' Japan' ' Yugoslavia' ' Peru' ' Outlying-US(Guam-USVI-etc)' ' Scotland' ' Trinadad&Tobago' ' Greece' ' Nicaragua' ' Vietnam' ' Hong' ' Ireland' ' Hungary' ' Holand-Netherlands'] ------------------------------------------------------- high_income | [' <=50K' ' >50K'] -------------------------------------------------------
# Shape of the subset whose workclass is the " ?" placeholder.
is_unknown_workclass = income["workclass"] == " ?"
income[is_unknown_workclass].shape
(1836, 15)
# Relabel the " ?" workclass placeholder as "Unknown" (note: no leading space,
# unlike the other labels) and show the resulting distinct values.
is_placeholder = income["workclass"] == " ?"
income.loc[is_placeholder, "workclass"] = "Unknown"
income["workclass"].unique()
array([' State-gov', ' Self-emp-not-inc', ' Private', ' Federal-gov', ' Local-gov', 'Unknown', ' Self-emp-inc', ' Without-pay', ' Never-worked'], dtype=object)
# to create a checkpoint: work on a copy so `income` keeps its text labels.
preprocessed_income = income.copy()
# Replace each categorical column by its integer category codes
# (pd.Categorical orders the categories, so codes follow that order).
for name in categorical_cols:
    encoded = pd.Categorical(preprocessed_income[name])
    preprocessed_income[name] = encoded.codes
preprocessed_income.head()
age | workclass | fnlwgt | education | education_num | marital_status | occupation | relationship | race | sex | capital_gain | capital_loss | hours_per_week | native_country | high_income | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 39 | 6 | 77516 | 9 | 13 | 4 | 1 | 1 | 4 | 1 | 2174 | 0 | 40 | 39 | 0 |
1 | 50 | 5 | 83311 | 9 | 13 | 2 | 4 | 0 | 4 | 1 | 0 | 0 | 13 | 39 | 0 |
2 | 38 | 3 | 215646 | 11 | 9 | 0 | 6 | 1 | 4 | 1 | 0 | 0 | 40 | 39 | 0 |
3 | 53 | 3 | 234721 | 1 | 7 | 2 | 6 | 0 | 2 | 1 | 0 | 0 | 40 | 39 | 0 |
4 | 28 | 3 | 338409 | 9 | 13 | 2 | 10 | 5 | 2 | 0 | 0 | 0 | 40 | 5 | 0 |
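In the Dataquest version of this exercise the categorical code for the " Private" workclass is 4, because " ?" was left in place there (and " ?" sorts before every other label). Here " ?" has been replaced with "Unknown" (no leading space), which sorts after all the space-prefixed labels, so " Private" gets code 3.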
# Look up the code pd.Categorical assigned to " Private" instead of hard-coding
# it: codes follow the sorted category labels, so the value shifts depending
# on whether " ?" was replaced with "Unknown" (see the note above this cell).
private_code = pd.Categorical(income["workclass"]).categories.get_loc(" Private")
private_incomes = preprocessed_income[preprocessed_income["workclass"] == private_code]
# NOTE: "public" here means every non-Private workclass, including the
# self-employed, "Unknown", " Without-pay" and " Never-worked" rows.
public_incomes = preprocessed_income[preprocessed_income["workclass"] != private_code]
private_incomes.shape, public_incomes.shape
((22696, 15), (9865, 15))
# Empirical class probabilities P(high_income = k): rows per label / total rows.
label_counts = preprocessed_income.groupby("high_income")["high_income"].count()
prob_high_income = label_counts / len(preprocessed_income)
prob_high_income
high_income 0 0.75919 1 0.24081 Name: high_income, dtype: float64
import math
bases = [math.e, 2, 4]  # not referenced in the visible cells
# Shannon entropy, in bits, of the high_income label distribution:
# H = -sum(p * log2(p)) over the class probabilities.
income_entropy = -sum(p * math.log(p, 2) for p in prob_high_income)
print(income_entropy)
0.7963839552022132
formula:
$IG(T,A) = Entropy(T) - \sum_{v \in values(A)} \frac{|T_v|}{|T|} \cdot Entropy(T_v)$, where $T_v$ is the subset of rows of $T$ whose value of attribute $A$ is $v$.
Alternate explanation: we compute the entropy of each subset produced by the split, weight it by the fraction of rows that fall into that subset, sum these weighted entropies, and subtract the total from the entropy before the split. If the result is positive, the split has lowered entropy; the higher the result, the more it has lowered entropy.
# Median age: the threshold used for the age split below.
age_median = preprocessed_income.age.median()
print(f"Median Age is {age_median}")
Median Age is 37.0
def calc_entropy(column):
    """Return the Shannon entropy, in bits, of a column of class labels.

    ``column`` must hold non-negative integers, because it is passed to
    ``np.bincount``; each distinct value counts as one class. Classes with a
    zero count contribute nothing to the sum.
    """
    n_items = len(column)
    total = 0
    for count in np.bincount(column):
        p = count / n_items
        if p > 0:
            total += p * math.log(p, 2)
    return -1 * total
# Split the rows at the median age, then measure how much that split lowers
# the entropy of high_income (weighted by the share of rows on each side).
ages = preprocessed_income["age"]
median_or_less = preprocessed_income[ages <= age_median]
high_median = preprocessed_income[ages > age_median]
n_total = len(preprocessed_income)
t_v_less = len(median_or_less) / n_total
t_v_high = len(high_median) / n_total
print(f"t_v_less = {t_v_less} | t_v_high = {t_v_high}")
t_v_less_entropy = calc_entropy(median_or_less["high_income"])
t_v_high_entropy = calc_entropy(high_median["high_income"])
print(f"t_v_less_entropy = {t_v_less_entropy} | t_v_high_entropy = {t_v_high_entropy}")
weighted_entropy = t_v_less * t_v_less_entropy + t_v_high * t_v_high_entropy
age_information_gain = income_entropy - np.sum(weighted_entropy)
print("information_gain:", age_information_gain)
t_v_less = 0.5122999907865238 | t_v_high = 0.4877000092134762 t_v_less_entropy = 0.5722871298658747 | t_v_high_entropy = 0.9353549188478923 information_gain: 0.047028661304691965
def calc_information_gain(df, feature_column, target_column):
    """Return the information gain of splitting ``df`` at the median of ``feature_column``.

    Rows with ``feature_column <= median`` form one subset and rows with
    ``feature_column > median`` form the other. The gain is the entropy of
    ``target_column`` over all of ``df``, minus the entropies of the two
    subsets weighted by their share of the rows (entropies via calc_entropy).
    """
    feature = df[feature_column]
    threshold = feature.median()
    subsets = (df[feature <= threshold], df[feature > threshold])
    parent_entropy = calc_entropy(df[target_column])

    n_rows = df.shape[0]
    weighted_entropy = 0
    for subset in subsets:
        share = subset.shape[0] / n_rows
        weighted_entropy += share * calc_entropy(subset[target_column])

    return parent_entropy - weighted_entropy
feature_cols = [
    "age", "workclass", "education_num", "marital_status", "occupation",
    "relationship", "race", "sex", "hours_per_week", "native_country",
]
# Information gain of a median split on each candidate feature.
information_gains = [
    calc_information_gain(preprocessed_income, col, "high_income")
    for col in feature_cols
]
print(information_gains)
# max() keeps the first of any tied maxima, like list.index(max(...)).
highest_gain = max(zip(feature_cols, information_gains), key=lambda pair: pair[1])[0]
highest_gain
[0.047028661304691965, 0.0013883016155813444, 0.06501298413277423, 0.1114272573715438, 0.0015822303843424645, 0.04736241665026941, 0.0, 0.0, 0.04062246867123487, 0.00013457344495848567]
'marital_status'
def id3(data, target, columns):
1. Create a node for the tree (the Root of this subtree).
2. If all values of the target attribute are 1, return the node with label = 1.
3. If all values of the target attribute are 0, return the node with label = 0.
4. Using information gain, find A, the column that splits the data best.
5. Find the median value in column A.
6. Split column A into values below or equal to the median (0) and values above the median (1).
7. For each possible value (0 or 1), vi, of A:
8. Add a new tree branch below Root that corresponds to rows of data where A = vi.
9. Let Examples(vi) be the subset of examples that have the value vi for A.
10. Below this new branch, add the subtree id3(data[A==vi], target, columns).
11. Return Root.
The algorithm creates only two branches from each node. This will simplify the process of constructing the tree, and make it easier to demonstrate the principles it involves.
The recursion happens in step 10: every node in the tree is created by its own call to id3(), and the final tree is the result of all of these calls.
# Candidate columns that income may be split on
columns = [
    "age",
    "workclass",
    "education_num",
    "marital_status",
    "occupation",
    "relationship",
    "race",
    "sex",
    "hours_per_week",
    "native_country",
]
def find_best_column(data, target_name, columns):
    """Return the name in ``columns`` whose median split has the highest information gain.

    Gains are computed with calc_information_gain on ``data`` against
    ``target_name``. On ties, the column listed first wins.

    Args:
        data: DataFrame holding the rows to evaluate.
        target_name: name of the integer-coded label column in ``data``.
        columns: candidate feature column names.
    """
    # Score the rows actually passed in. Scoring a fixed global frame instead
    # makes id3 (which calls this on ever-smaller subsets) keep choosing a
    # column that may not split the current subset at all, so one side of
    # the split is the whole subset again and the recursion never ends.
    information_gains = [
        calc_information_gain(data, each_col, target_name) for each_col in columns
    ]
    highest_gain = columns[information_gains.index(max(information_gains))]
    return highest_gain
# Best first split for the full income data set.
income_split = find_best_column(preprocessed_income, target_name="high_income", columns=columns)
print(income_split)
marital_status
# Create the data set that we used in the example on the last screen,
# naming its columns directly in the constructor
data = pd.DataFrame(
    [
        [0, 20, 0],
        [0, 60, 2],
        [0, 40, 1],
        [1, 25, 1],
        [1, 35, 2],
        [1, 55, 1],
    ],
    columns=["high_income", "age", "marital_status"],
)
data
high_income | age | marital_status | |
---|---|---|---|
0 | 0 | 20 | 0 |
1 | 0 | 60 | 2 |
2 | 0 | 40 | 1 |
3 | 1 | 25 | 1 |
4 | 1 | 35 | 2 |
5 | 1 | 55 | 1 |
# Leaf labels collected by id3: one entry appended per leaf reached.
label_1s = []
label_0s = []


def id3(data, target, columns):
    """Grow a median-split ID3 tree over ``data``, recording each leaf's label.

    Nothing is returned. Each leaf appends its label to the module-level lists
    ``label_0s`` (label 0) or ``label_1s`` (label 1).

    Args:
        data: DataFrame of rows at this node.
        target: name of the 0/1 label column.
        columns: candidate feature columns passed to find_best_column.
    """
    unique_targets = pd.unique(data[target])
    if len(unique_targets) == 0:
        # Empty subset: nothing to classify, and scoring splits would divide by zero.
        return
    if len(unique_targets) == 1:
        if 0 in unique_targets:
            label_0s.append(0)
        elif 1 in unique_targets:
            label_1s.append(1)
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()
    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]

    # If the median split leaves one side empty, the other side is the same
    # rows again and recursing would never terminate. Make this node a leaf
    # labelled with the majority target (Series.mode is sorted, so ties go to 0).
    if left_split.empty or right_split.empty:
        majority = data[target].mode().iloc[0]
        if majority == 0:
            label_0s.append(0)
        elif majority == 1:
            label_1s.append(1)
        return

    for split in (left_split, right_split):
        id3(split, target, columns)


id3(data, "high_income", ["age", "marital_status"])
--------------------------------------------------------------------------- RecursionError Traceback (most recent call last) <ipython-input-29-3717d0755180> in <module> 27 id3(split, target, columns) 28 ---> 29 id3(data, "high_income", ["age", "marital_status"]) <ipython-input-29-3717d0755180> in id3(data, target, columns) 25 26 for split in [left_split, right_split]: ---> 27 id3(split, target, columns) 28 29 id3(data, "high_income", ["age", "marital_status"]) <ipython-input-29-3717d0755180> in id3(data, target, columns) 25 26 for split in [left_split, right_split]: ---> 27 id3(split, target, columns) 28 29 id3(data, "high_income", ["age", "marital_status"]) <ipython-input-29-3717d0755180> in id3(data, target, columns) 25 26 for split in [left_split, right_split]: ---> 27 id3(split, target, columns) 28 29 id3(data, "high_income", ["age", "marital_status"]) <ipython-input-29-3717d0755180> in id3(data, target, columns) 25 26 for split in [left_split, right_split]: ---> 27 id3(split, target, columns) 28 29 id3(data, "high_income", ["age", "marital_status"]) <ipython-input-29-3717d0755180> in id3(data, target, columns) 25 26 for split in [left_split, right_split]: ---> 27 id3(split, target, columns) 28 29 id3(data, "high_income", ["age", "marital_status"]) ... last 5 frames repeated, from the frame below ... <ipython-input-29-3717d0755180> in id3(data, target, columns) 25 26 for split in [left_split, right_split]: ---> 27 id3(split, target, columns) 28 29 id3(data, "high_income", ["age", "marital_status"]) RecursionError: maximum recursion depth exceeded while calling a Python object
# Create the data set that we used in the example on the last screen
data = pd.DataFrame([
    [0, 20, 0],
    [0, 60, 2],
    [0, 40, 1],
    [1, 25, 1],
    [1, 35, 2],
    [1, 55, 1],
])
# Assign column names to the data
data.columns = ["high_income", "age", "marital_status"]

# Leaf labels collected by id3: one entry appended per leaf reached.
label_1s = []
label_0s = []


def id3(data, target, columns):
    """Grow a median-split ID3 tree over ``data``, recording each leaf's label.

    Nothing is returned. Each leaf appends its label to the module-level lists
    ``label_0s`` (label 0) or ``label_1s`` (label 1).

    Args:
        data: DataFrame of rows at this node.
        target: name of the 0/1 label column.
        columns: candidate feature columns passed to find_best_column.
    """
    unique_targets = pd.unique(data[target])
    if len(unique_targets) == 0:
        # Empty subset: nothing to classify, and scoring splits would divide by zero.
        return
    if len(unique_targets) == 1:
        if 0 in unique_targets:
            label_0s.append(0)
        elif 1 in unique_targets:
            label_1s.append(1)
        return

    best_column = find_best_column(data, target, columns)
    column_median = data[best_column].median()
    left_split = data[data[best_column] <= column_median]
    right_split = data[data[best_column] > column_median]

    # If the median split leaves one side empty, the other side is the same
    # rows again and recursing would never terminate. Make this node a leaf
    # labelled with the majority target (Series.mode is sorted, so ties go to 0).
    if left_split.empty or right_split.empty:
        majority = data[target].mode().iloc[0]
        if majority == 0:
            label_0s.append(0)
        elif majority == 1:
            label_1s.append(1)
        return

    for split in (left_split, right_split):
        id3(split, target, columns)


id3(data, "high_income", ["age", "marital_status"])
--------------------------------------------------------------------------- RecursionError Traceback (most recent call last) <ipython-input-30-1471ff2df903> in <module> 34 35 ---> 36 id3(data, "high_income", ["age", "marital_status"]) <ipython-input-30-1471ff2df903> in id3(data, target, columns) 31 32 for split in [left_split, right_split]: ---> 33 id3(split, target, columns) 34 35 <ipython-input-30-1471ff2df903> in id3(data, target, columns) 31 32 for split in [left_split, right_split]: ---> 33 id3(split, target, columns) 34 35 <ipython-input-30-1471ff2df903> in id3(data, target, columns) 31 32 for split in [left_split, right_split]: ---> 33 id3(split, target, columns) 34 35 <ipython-input-30-1471ff2df903> in id3(data, target, columns) 31 32 for split in [left_split, right_split]: ---> 33 id3(split, target, columns) 34 35 <ipython-input-30-1471ff2df903> in id3(data, target, columns) 31 32 for split in [left_split, right_split]: ---> 33 id3(split, target, columns) 34 35 ... last 5 frames repeated, from the frame below ... <ipython-input-30-1471ff2df903> in id3(data, target, columns) 31 32 for split in [left_split, right_split]: ---> 33 id3(split, target, columns) 34 35 RecursionError: maximum recursion depth exceeded while calling a Python object