In [ ]:
# setup our standard computation environment
import numpy as np, pandas as pd, matplotlib.pyplot as plt, seaborn as sns
%matplotlib inline
sns.set_style('darkgrid')
sns.set_context('poster')
In [ ]:
# set random seed for reproducibility
np.random.seed(12345)
In [ ]:
# some data, based on a real Abie project
# Fix: the original cell called bare `array(...)`, but numpy is imported as
# `np` above, so the cell raised NameError. Qualify the calls so it runs.
# Data: 100 subjects (67 undergraduates, then 33 medical students), their age,
# and their mean salting behavior.
df = pd.DataFrame(dict(
    age=np.array([
        19., 19., 20., 22., 19., 20., 20., 24., 20., 20., 19.,
        20., 19., 21., 21., 20., 19., 19., 22., 20., 20., 19.,
        18., 19., 19., 19., 20., 18., 19., 19., 19., 22., 21.,
        20., 20., 18., 19., 20., 22., 18., 19., 21., 22., 21.,
        26., 20., 19., 20., 18., 20., 20., 18., 27., 19., 20.,
        18., 20., 19., 20., 19., 21., 19., 19., 22., 18., 20.,
        20., 24., 18., 19., 18., 20., 19., 20., 22., 19., 19.,
        18., 19., 26., 22., 20., 19., 19., 20., 21., 21., 19.,
        22., 19., 19., 20., 19., 22., 22., 22., 20., 20., 18.,
        20.]),
    # first 67 rows are undergraduates, the remaining 33 are medical students
    level=np.array(['Undergraduate'] * 67 + ['Medical Student'] * 33,
                   dtype=object),
    meansalt17to23=np.array([
        0.0555556, 0.3333333, 0.7222222, 0.       , 0.0555556,
        0.0555556, 0.2777778, 0.2222222, 0.3333333, 0.5555556,
        0.2777778, 0.1111111, 0.5555556, 0.5555556, 0.3333333,
        0.5      , 0.1666667, 0.4444444, 0.1111111, 0.6666667,
        0.6666667, 0.2222222, 0.2222222, 0.0555556, 0.1111111,
        0.2777778, 0.       , 0.1666667, 0.3333333, 0.4444444,
        0.2777778, 0.1666667, 0.2222222, 0.3333333, 0.0555556,
        0.4444444, 0.5555556, 0.2777778, 0.3888889, 0.2222222,
        0.2777778, 0.3888889, 0.3333333, 0.0555556, 0.4444444,
        0.3333333, 0.5555556, 0.2222222, 0.1666667, 0.1666667,
        0.2222222, 0.2777778, 0.1666667, 0.0555556, 0.6111111,
        0.8333333, 0.2777778, 0.1111111, 0.9444444, 0.3333333,
        0.2777778, 0.2222222, 0.0555556, 0.8888889, 0.0555556,
        0.5      , 0.       , 0.3125   , 0.5625   , 0.       ,
        0.0625   , 0.125    , 0.625    , 0.0625   , 0.125    ,
        0.       , 0.5625   , 0.125    , 0.3125   , 0.8125   ,
        0.4375   , 0.       , 0.1875   , 0.25     , 0.4375   ,
        0.3125   , 0.       , 0.1875   , 0.0625   , 0.25     ,
        0.       , 0.5      , 0.375    , 0.0625   , 0.1875   ,
        0.1875   , 0.4375   , 0.375    , 0.1875   , 0.3125   ])
    ))

Inspired by the very practical need to do some data mining (aka p-value hacking) of the results of our salt game experiment, I am going to implement my own C4.5-style decision tree algorithm. It will be un-optimized, emphasizing simplicity over speed. And perhaps I will make my students do it, too. But it might be too hard.

We need a criterion for node quality in order to choose splits.

In [ ]:
# use negative MSE, so that higher values are better
def neg_mse(example_indices):
    """Split-quality criterion: the negated (population) variance of the
    labels at the given example rows -- larger is better, 0 is a pure node."""
    labels = df.meansalt17to23[example_indices]
    return -np.var(labels)

# will have a special one for finding splits that have large difference between treatment and control
In [ ]:
# overall mse impurity: the criterion on all rows, i.e. the quality of an
# unsplit (depth-0) tree -- the baseline any split should improve on
neg_mse(df.index)

Need a way to tell if a variable is numeric or categorical

In [ ]:
# for a given split, eval split quality
# two cases: categorical and numeric
# the string-valued column reports dtype('O') (object)
split = 'level'
df[split].dtype
In [ ]:
# the numeric column reports float64
split = 'age'
df[split].dtype
In [ ]:
# TODO: make this a robust method
# robust classifier for feature kinds (replaces the TODO version)
def data_type(s):
    """Classify a value or array-like as 'numeric' or 'categorical'.

    s : int, float, str, numpy scalar, or anything carrying a .dtype
        (pd.Series, np.ndarray)

    Returns 'numeric' for number-like scalars and numeric dtypes, and
    'categorical' otherwise (strings, object-dtype columns, ...).

    Fixes over the original: no Py2-only `unicode` reference, no fragile
    `np.dtype(Series)` call, and integer-dtype columns now classify as
    'numeric' instead of falling through and returning None.
    """
    # array-likes and numpy scalars carry a dtype; classify by the dtype
    dtype = getattr(s, 'dtype', None)
    if dtype is not None:
        return 'numeric' if np.issubdtype(dtype, np.number) else 'categorical'
    # plain Python scalars
    if isinstance(s, (int, float)):
        return 'numeric'
    # everything else (strings, etc.) is treated as categorical
    return 'categorical'
# sanity checks: scalars and whole DataFrame columns both classify correctly
assert data_type(20) == 'numeric'
assert data_type(20.) == 'numeric'
assert data_type('Undergraduate') == 'categorical'
assert data_type(df['level']) == 'categorical'
assert data_type(df['age']) == 'numeric'
In [ ]:
df.level.unique()

Represent a decision tree as a nested dictionary

In [ ]:
# each node is a dict, with keys for:
#   split : str, optional; feature column that node splits on (node is leaf if not present)
#   thresh : float, optional; if split col is numeric, value to split above and below
#   child_VALS : dicts, with VALS = leq, greater for numeric splits, and VALS = df[split].unique() for categorical
#   examples : rows of pd.DataFrame which pass through this node

# depth 0 tree: a single leaf holding every example
dt_ex0 = {'examples': df.index}

# depth 1 tree: one numeric split on age at 25
dt_ex = \
{'examples': df.index,
 'split': 'age',
 'thresh': 25,
 'child_leq': {'examples': df[df.age<=25].index},
 'child_greater': {'examples': df[df.age>25].index},
 }

# depth 2 example: the age split, then a categorical split on level in each branch
# NOTE: 'Resident' and 'Attending' never occur in df, so those children simply
# hold empty example sets
dt_ex = \
{'examples': df.index,
 'split': 'age',
 'thresh': 25,
 'child_leq': {'examples': df[df.age<=25].index,
               'split': 'level',
               'child_Undergraduate': {'examples':df[(df.age<=25)&(df.level=='Undergraduate')].index},
               'child_Medical Student': {'examples':df[(df.age<=25)&(df.level=='Medical Student')].index},
               'child_Resident': {'examples':df[(df.age<=25)&(df.level=='Resident')].index},
               'child_Attending': {'examples':df[(df.age<=25)&(df.level=='Attending')].index},},
 'child_greater': {'examples': df[df.age>25].index,
                   'split': 'level',
               'child_Undergraduate': {'examples':df[(df.age>25)&(df.level=='Undergraduate')].index},
               'child_Medical Student': {'examples':df[(df.age>25)&(df.level=='Medical Student')].index},
               'child_Resident': {'examples':df[(df.age>25)&(df.level=='Resident')].index},
               'child_Attending': {'examples':df[(df.age>25)&(df.level=='Attending')].index},},
 }

To predict for a row of data, recursively descend the tree and use the mean of the leaf node it reaches.

In [ ]:
def predict(tree, x):
    """Recursively descend `tree` for example `x`; at a leaf, return the
    mean of the training labels stored there."""
    feature = tree.get('split')
    if not feature:
        # leaf node: average the labels of the examples that landed here
        rows = np.array(tree['examples'])
        return df.meansalt17to23[rows].mean()
    # internal node: route x down the matching branch
    value = x[feature]
    if data_type(value) == 'numeric':
        branch = 'child_leq' if value <= tree['thresh'] else 'child_greater'
    else:
        branch = 'child_' + str(value)
    return predict(tree[branch], x)
predict(dt_ex, {'age':20, 'level':'Undergraduate'})
In [ ]:
# age 26 falls in the sparse age>25 branch; compare its leaf prediction per level
for level in df.level.unique():
    print level, predict(dt_ex, {'age':26, 'level':level})

On second thought, it would be better to parameterize prediction with an aggregation function — it could compute the mean, or something else.

In [ ]:
def mean(rows):
    """Leaf aggregator: the mean label of the given example rows."""
    leaf_labels = df.meansalt17to23[rows]
    return leaf_labels.mean()

# to predict for a row of data, recursively descend tree and apply agg_func at the leaf
def predict(tree, x, agg_func):
    """ predict with decision tree `tree` for example x
    tree : nested dict, as described above
    x : example dict or pd.Series
    agg_func : function mapping a set of example rows to a prediction value
    """
    feature = tree.get('split')
    if not feature:
        # leaf node: delegate aggregation to the caller-supplied function
        return agg_func(np.array(tree['examples']))
    # internal node: route x down the matching branch
    value = x[feature]
    if data_type(value) == 'numeric':
        branch = 'child_leq' if value <= tree['thresh'] else 'child_greater'
    else:
        branch = 'child_' + str(value)
    return predict(tree[branch], x, agg_func)

predict(dt_ex, {'age':20, 'level':'Undergraduate'}, mean)
In [ ]:
# same check as before, now through the agg_func-parameterized predictor
for level in df.level.unique():
    print level, predict(dt_ex, {'age':26, 'level':level}, mean)
In [ ]:
predict(dt_ex, df.loc[0], mean)

Now to grow a tree, search through all candidate splits, pick best, and recurse

In [ ]:
def fit(X, y, criteria, depth=0, max_depth=np.inf, min_examples=2):
    """ recursively build greedy tree (depth-first)
    
    X : pd.DataFrame of feature vectors (as rows)
    y : pd.Series of labels
    criteria : function mapping a set of example row labels to a quality
        score; HIGHER is better (e.g. neg_mse above)
    depth : int, current depth of recursion
    max_depth : int, maximum depth of tree
    min_examples : int, minimum number of examples for an internal node

    Returns the nested-dict tree representation described above.
    """
   
    # simple cases:
    # not enough examples
    if len(X.index) <= min_examples:
        return {'examples': X.index}
    
    # tree too deep
    if depth >= max_depth:
        return {'examples': X.index}
    
    # hard case: search for best split
    # split quality = size-weighted average of `criteria` over the partitions
    n = float(len(X.index))
    max_quality = -np.inf
    best_split = None
    
    for split in X.columns:
        if data_type(X[split]) == 'numeric':
            # try every observed value of the column as a threshold
            for thresh in X[split].unique():
                left_vals = y[X[split] <= thresh]
                right_vals = y[X[split] > thresh]
                # when thresh is the column max, right_vals is empty; with the
                # neg_mse criterion that makes quality nan, so the comparison
                # below is False and the degenerate split is skipped
                quality = len(left_vals) / n * criteria(left_vals.index) \
                            + len(right_vals) / n * criteria(right_vals.index)
                if quality > max_quality:
                    max_quality = quality
                    best_split = split
                    best_thresh = thresh
        else: # split var is categorical
            quality = 0
            for cat in X[split].unique():
                cat_vals = y[X[split] == cat]
                quality += len(cat_vals) / n * criteria(cat_vals.index)
            if quality > max_quality:
                max_quality = quality
                best_split = split

    if not best_split:
        # no usable split found: this node becomes a leaf
        return {'examples': X.index}
    else: # found a split to use, so partition on it and recurse
        tree = {'examples': X.index,
                'split': best_split}
        if data_type(X[best_split]) == 'numeric':
            # numeric columns are kept, so deeper nodes may split them again
            tree['thresh'] = best_thresh

            rows = X[best_split] <= best_thresh
            tree['child_leq'] = fit(X[rows], y[rows], criteria, depth+1, max_depth, min_examples)

            rows = X[best_split] > best_thresh
            tree['child_greater'] = fit(X[rows], y[rows], criteria, depth+1, max_depth, min_examples)
        
        else: # best split var is categorical
            # a categorical split consumes the column, so drop it before recursing
            for cat in X[best_split].unique():
                rows = X[best_split] == cat
                tree['child_'+cat] = fit(X[rows].drop(best_split, axis=1), y[rows], criteria, depth+1, max_depth, min_examples)
    
        return tree

# a depth-1 fit on the class data, sanity-checked with a single prediction
tree = fit(df.filter(['age', 'level']), df.meansalt17to23, neg_mse, max_depth=1)
predict(tree, df.loc[0], mean)

How do we test this?

In [ ]:
# interactively, with a recursive tree printer
def print_tree(tree, depth=0):
    indent = '    '*depth
    if 'split' in tree:
        if 'thresh' in tree:
            print indent + 'if X[%s] <= %.2f:' % (tree['split'], tree['thresh'])
            print_tree(tree['child_leq'], depth+1)
            print indent + 'else:'
            print_tree(tree['child_greater'], depth+1)
        else:
            for k in tree:
                if k.startswith('child_'):
                    print indent + 'if X[%s] == "%s":' % (tree['split'], k.replace('child_', ''))
                    print_tree(tree[k], depth+1)
    else:
        rows = np.array(tree['examples'])
        print indent + 'return agg_func(%d, %d, ...) [%d examples]' % (rows[0], rows[1], len(rows))
                                                          

# depth one, see that it works right
# NOTE(review): df.loc[:n] is label-based and inclusive (rows 0..3, 4 rows),
# while df.meansalt17to23[:n] is a positional slice (3 rows) -- confirm these
# are meant to line up
n=3
tree = fit(df.loc[:n, ['age', ]], df.meansalt17to23[:n], neg_mse, max_depth=1)
print_tree(tree)
In [ ]:
df.filter(['age', 'level', 'meansalt17to23']).loc[:n]

Test by replicating the class exercise:

In [ ]:
# reproducible noisy-cosine regression problem for testing the tree
np.random.seed(123456)

x_true = np.linspace(0,10,1000)
y_true = np.cos(x_true)

# 100 random x locations, labels = cos(x) plus gaussian noise (sd 0.4)
x_train = np.random.choice(x_true, size=100)
y_train = np.random.normal(np.cos(x_train), .4)

plt.plot(x_true, y_true, label='truth')
plt.plot(x_train, y_train, 's', label='train')
plt.legend()
In [ ]:
X = pd.DataFrame({'x':x_train})
y = pd.Series(y_train)

# redefine the criterion and aggregator against the new label series `y`
# NOTE(review): pandas .var() uses ddof=1 (sample variance), while the earlier
# neg_mse used np.var (ddof=0, population variance) -- confirm which is intended
def neg_mse(rows):
    return -y[rows].var()

def mean(rows):
    return y[rows].mean()


# grow a deeper tree: max_depth=25, at least 25 examples per internal node
tree = fit(X, y, neg_mse, 0, 25, 25)
print_tree(tree)
In [ ]:
y_pred = [predict(tree, {'x': x}, mean) for x in x_true]
In [ ]:
# the tree's piecewise-constant prediction vs the truth and the training data
plt.plot(x_true, y_true, label='truth')
plt.plot(x_train, y_train, 's', label='train')
plt.plot(x_true, y_pred, label='pred')
plt.legend()
In [ ]: