Copyright 2017 - 2020 Patrick Hall and the H2O.ai team
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
DISCLAIMER: This notebook is not legal compliance advice.
Decision trees and decision tree ensembles are some of the most popular machine learning models used in commercial practice. They can train and make predictions on data containing character values and missing values - both common in large commercial data stores. Single decision trees are easily represented as directed graphs, which can drastically increase their interpretability and transparency. Decision tree ensembles (e.g., random forests and gradient boosting machines (GBMs)) can be used to increase the accuracy and stability of single decision tree models, but are far less interpretable than single trees. These characteristics of decision trees will be leveraged here to increase transparency and accountability in complex, nonlinear, machine learning models.
This notebook starts by training a GBM on the UCI credit card default data using the popular open source library, h2o. A single decision tree surrogate model will then be trained on the original UCI credit card default data and the predictions from the h2o GBM, to create an approximate flow chart for the GBM's global decision-making processes. A technique known as leave-one-covariate/column-out (LOCO) will then be used to generate local explanations for any row-wise prediction made by the GBM model. Finally, local explanations are ensembled together from multiple similar models to increase explanation stability.
Note: As of the h2o 3.24 "Yates" release, Shapley values are supported in h2o and LOCO is no longer recommended. To see Shapley values for an h2o GBM in action please see: https://github.com/jphall663/interpretable_machine_learning_with_python/blob/master/dia.ipynb.
In general, NumPy and Pandas will be used for data manipulation purposes and h2o will be used for modeling tasks.
# imports
# h2o Python API with specific classes
import h2o
from h2o.estimators.gbm import H2OGradientBoostingEstimator # for GBM
from h2o.estimators.random_forest import H2ORandomForestEstimator # for single tree
from h2o.backend import H2OLocalServer # for plotting local tree in-notebook
import numpy as np # array, vector, matrix calculations
import pandas as pd # DataFrame handling
# system packages for calling external graphviz processes
import os
import re
import subprocess
# in-notebook display
from IPython.display import Image
from IPython.display import display
%matplotlib inline
H2o is both a library and a server. The machine learning algorithms in the library take advantage of the multithreaded and distributed architecture provided by the server to train models extremely efficiently. The API for the library was imported above in cell 1, but the server still needs to be started.
h2o.init(max_mem_size='2G') # start h2o
h2o.remove_all() # remove any existing data structures from h2o memory
Checking whether there is an H2O instance running at http://localhost:54321 ..... not found. Attempting to start a local H2O server... Java Version: openjdk version "1.8.0_232"; OpenJDK Runtime Environment (build 1.8.0_232-8u232-b09-0ubuntu1~16.04.1-b09); OpenJDK 64-Bit Server VM (build 25.232-b09, mixed mode) Starting server from /home/patrickh/workspace/interpretable_machine_learning_with_python/env_iml/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar Ice root: /tmp/tmpjpy5faje JVM stdout: /tmp/tmpjpy5faje/h2o_patrickh_started_from_python.out JVM stderr: /tmp/tmpjpy5faje/h2o_patrickh_started_from_python.err Server is running at http://127.0.0.1:54321 Connecting to H2O server at http://127.0.0.1:54321 ... successful. Warning: Your H2O cluster version is too old (4 months and 14 days)! Please download and install the latest version from http://h2o.ai/download/
H2O cluster uptime: | 01 secs |
H2O cluster timezone: | America/New_York |
H2O data parsing timezone: | UTC |
H2O cluster version: | 3.26.0.3 |
H2O cluster version age: | 4 months and 14 days !!! |
H2O cluster name: | H2O_from_python_patrickh_6414b6 |
H2O cluster total nodes: | 1 |
H2O cluster free memory: | 1.778 Gb |
H2O cluster total cores: | 8 |
H2O cluster allowed cores: | 8 |
H2O cluster status: | accepting new members, healthy |
H2O connection url: | http://127.0.0.1:54321 |
H2O connection proxy: | None |
H2O internal security: | False |
H2O API Extensions: | Amazon S3, XGBoost, Algos, AutoML, Core V3, Core V4 |
Python version: | 3.6.3 final |
UCI credit card default data: https://archive.ics.uci.edu/ml/datasets/default+of+credit+card+clients
The UCI credit card default data contains demographic and payment information about credit card customers in Taiwan in the year 2005. The data set contains 23 input variables:
- LIMIT_BAL: Amount of given credit (NT dollar)
- SEX: 1 = male; 2 = female
- EDUCATION: 1 = graduate school; 2 = university; 3 = high school; 4 = others
- MARRIAGE: 1 = married; 2 = single; 3 = others
- AGE: Age in years
- PAY_0, PAY_2-PAY_6: History of past payment; PAY_0 = the repayment status in September, 2005; PAY_2 = the repayment status in August, 2005; ...; PAY_6 = the repayment status in April, 2005. The measurement scale for the repayment status is: -1 = pay duly; 1 = payment delay for one month; 2 = payment delay for two months; ...; 8 = payment delay for eight months; 9 = payment delay for nine months and above.
- BILL_AMT1-BILL_AMT6: Amount of bill statement (NT dollar). BILL_AMT1 = amount of bill statement in September, 2005; BILL_AMT2 = amount of bill statement in August, 2005; ...; BILL_AMT6 = amount of bill statement in April, 2005.
- PAY_AMT1-PAY_AMT6: Amount of previous payment (NT dollar). PAY_AMT1 = amount paid in September, 2005; PAY_AMT2 = amount paid in August, 2005; ...; PAY_AMT6 = amount paid in April, 2005.

These 23 input variables are used to predict the target variable: whether or not a customer defaulted on their credit card bill in late 2005.
Because h2o accepts both numeric and character inputs, some variables will be recoded into more transparent character values.
The credit card default data is available as an .xls file. Pandas reads .xls files directly with read_excel(), so it's used to load the credit card default data and give the prediction target a shorter name: DEFAULT_NEXT_MONTH.
# import XLS file
path = 'default_of_credit_card_clients.xls'
data = pd.read_excel(path,
skiprows=1)
# remove spaces from target column name
data = data.rename(columns={'default payment next month': 'DEFAULT_NEXT_MONTH'})
The shorthand name y is assigned to the prediction target. X is assigned to all other input variables in the credit card default data except the row identifier, ID.
# assign target and inputs for GBM
y = 'DEFAULT_NEXT_MONTH'
X = [name for name in data.columns if name not in [y, 'ID']]
print('y =', y)
print('X =', X)
y = DEFAULT_NEXT_MONTH X = ['LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6']
This simple function maps the original integer values of the input variables to longer, more understandable character string values taken from the UCI credit card default data dictionary. These character values can be used directly in h2o decision tree models, and the function returns the original Pandas DataFrame as an h2o object, an H2OFrame. H2o models cannot run on Pandas DataFrames; they require H2OFrames.
def recode_cc_data(frame):
""" Recodes numeric categorical variables into categorical character variables
with more transparent values.
Args:
frame: Pandas DataFrame version of UCI credit card default data.
Returns:
H2OFrame with recoded values.
"""
# define recoded values
sex_dict = {1:'male', 2:'female'}
education_dict = {0:'other', 1:'graduate school', 2:'university', 3:'high school',
4:'other', 5:'other', 6:'other'}
marriage_dict = {0:'other', 1:'married', 2:'single', 3:'divorced'}
pay_dict = {-2:'no consumption', -1:'pay duly', 0:'use of revolving credit', 1:'1 month delay',
2:'2 month delay', 3:'3 month delay', 4:'4 month delay', 5:'5 month delay', 6:'6 month delay',
7:'7 month delay', 8:'8 month delay', 9:'9+ month delay'}
# recode values using Pandas apply() and anonymous function
frame['SEX'] = frame['SEX'].apply(lambda i: sex_dict[i])
frame['EDUCATION'] = frame['EDUCATION'].apply(lambda i: education_dict[i])
frame['MARRIAGE'] = frame['MARRIAGE'].apply(lambda i: marriage_dict[i])
for name in frame.columns:
if name in ['PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6']:
frame[name] = frame[name].apply(lambda i: pay_dict[i])
return h2o.H2OFrame(frame)
data = recode_cc_data(data)
Parse progress: |█████████████████████████████████████████████████████████| 100%
In h2o, a numeric variable can be treated as numeric or categorical. The target variable DEFAULT_NEXT_MONTH takes on values of 0 or 1. To ensure this numeric variable is treated as a categorical variable, the asfactor() function is used to declare it explicitly categorical.
data[y] = data[y].asfactor()
The h2o describe() function displays a brief description of the credit card default data. For the categorical input variables SEX, EDUCATION, MARRIAGE, and PAY_0-PAY_6, the new character values created above in cell 5 are visible. Basic descriptive statistics are displayed for numeric inputs. Also, it's easy to see there are no missing values in this dataset, which will be an important consideration for calculating LOCO values in sections 5 and 6.
data[X + [y]].describe()
Rows:30000 Cols:24
LIMIT_BAL | SEX | EDUCATION | MARRIAGE | AGE | PAY_0 | PAY_2 | PAY_3 | PAY_4 | PAY_5 | PAY_6 | BILL_AMT1 | BILL_AMT2 | BILL_AMT3 | BILL_AMT4 | BILL_AMT5 | BILL_AMT6 | PAY_AMT1 | PAY_AMT2 | PAY_AMT3 | PAY_AMT4 | PAY_AMT5 | PAY_AMT6 | DEFAULT_NEXT_MONTH | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
type | int | enum | enum | enum | int | enum | enum | enum | enum | enum | enum | int | int | int | int | int | int | int | int | int | int | int | int | enum |
mins | 10000.0 | | | | 21.0 | | | | | | | -165580.0 | -69777.0 | -157264.0 | -170000.0 | -81334.0 | -339603.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | |
mean | 167484.32266666688 | | | | 35.48549999999994 | | | | | | | 51223.33090000009 | 49179.07516666668 | 47013.15479999971 | 43262.9489666666 | 40311.40096666653 | 38871.76039999991 | 5663.580500000014 | 5921.16350000001 | 5225.681500000005 | 4826.076866666661 | 4799.387633333302 | 5215.502566666664 | |
maxs | 1000000.0 | | | | 79.0 | | | | | | | 964511.0 | 983931.0 | 1664089.0 | 891586.0 | 927171.0 | 961664.0 | 873552.0 | 1684259.0 | 896040.0 | 621000.0 | 426529.0 | 528666.0 | |
sigma | 129747.66156720225 | | | | 9.21790406809016 | | | | | | | 73635.86057552959 | 71173.76878252836 | 69349.38742703681 | 64332.85613391641 | 60797.1557702648 | 59554.10753674574 | 16563.280354025763 | 23040.870402057226 | 17606.961469803115 | 15666.159744031993 | 15278.305679144793 | 17777.465775435332 | |
zeros | 0 | | | | 0 | | | | | | | 2008 | 2506 | 2870 | 3195 | 3506 | 4020 | 5249 | 5396 | 5968 | 6408 | 6703 | 7173 | |
missing | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
0 | 20000.0 | female | university | married | 24.0 | 2 month delay | 2 month delay | pay duly | pay duly | no consumption | no consumption | 3913.0 | 3102.0 | 689.0 | 0.0 | 0.0 | 0.0 | 0.0 | 689.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1 |
1 | 120000.0 | female | university | single | 26.0 | pay duly | 2 month delay | use of revolving credit | use of revolving credit | use of revolving credit | 2 month delay | 2682.0 | 1725.0 | 2682.0 | 3272.0 | 3455.0 | 3261.0 | 0.0 | 1000.0 | 1000.0 | 1000.0 | 0.0 | 2000.0 | 1 |
2 | 90000.0 | female | university | single | 34.0 | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | 29239.0 | 14027.0 | 13559.0 | 14331.0 | 14948.0 | 15549.0 | 1518.0 | 1500.0 | 1000.0 | 1000.0 | 1000.0 | 5000.0 | 0 |
3 | 50000.0 | female | university | married | 37.0 | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | 46990.0 | 48233.0 | 49291.0 | 28314.0 | 28959.0 | 29547.0 | 2000.0 | 2019.0 | 1200.0 | 1100.0 | 1069.0 | 1000.0 | 0 |
4 | 50000.0 | male | university | married | 57.0 | pay duly | use of revolving credit | pay duly | use of revolving credit | use of revolving credit | use of revolving credit | 8617.0 | 5670.0 | 35835.0 | 20940.0 | 19146.0 | 19131.0 | 2000.0 | 36681.0 | 10000.0 | 9000.0 | 689.0 | 679.0 | 0 |
5 | 50000.0 | male | graduate school | single | 37.0 | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | 64400.0 | 57069.0 | 57608.0 | 19394.0 | 19619.0 | 20024.0 | 2500.0 | 1815.0 | 657.0 | 1000.0 | 1000.0 | 800.0 | 0 |
6 | 500000.0 | male | graduate school | single | 29.0 | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | use of revolving credit | 367965.0 | 412023.0 | 445007.0 | 542653.0 | 483003.0 | 473944.0 | 55000.0 | 40000.0 | 38000.0 | 20239.0 | 13750.0 | 13770.0 | 0 |
7 | 100000.0 | female | university | single | 23.0 | use of revolving credit | pay duly | pay duly | use of revolving credit | use of revolving credit | pay duly | 11876.0 | 380.0 | 601.0 | 221.0 | -159.0 | 567.0 | 380.0 | 601.0 | 0.0 | 581.0 | 1687.0 | 1542.0 | 0 |
8 | 140000.0 | female | high school | married | 28.0 | use of revolving credit | use of revolving credit | 2 month delay | use of revolving credit | use of revolving credit | use of revolving credit | 11285.0 | 14096.0 | 12108.0 | 12211.0 | 11793.0 | 3719.0 | 3329.0 | 0.0 | 432.0 | 1000.0 | 1000.0 | 1000.0 | 0 |
9 | 20000.0 | male | high school | single | 35.0 | no consumption | no consumption | no consumption | no consumption | pay duly | pay duly | 0.0 | 0.0 | 0.0 | 0.0 | 13007.0 | 13912.0 | 0.0 | 0.0 | 0.0 | 13007.0 | 1122.0 | 0.0 | 0 |
The credit card default data is split into training and test sets to monitor and prevent overtraining. Reproducibility is also an important factor in creating trustworthy models, and randomly splitting datasets can introduce randomness in model predictions and other results. A random seed is used here to ensure the data split is reproducible.
# split into training and validation
train, test = data.split_frame([0.7], seed=12345)
# summarize split
print('Train data rows = %d, columns = %d' % (train.shape[0], train.shape[1]))
print('Test data rows = %d, columns = %d' % (test.shape[0], test.shape[1]))
Train data rows = 21060, columns = 25 Test data rows = 8940, columns = 25
Many tuning parameters must be specified to train a GBM using h2o. Typically a grid search would be performed with the H2OGridSearch class to identify the best parameters for a given modeling task. For brevity's sake, a previously discovered set of good tuning parameters is specified here. Because gradient boosting methods typically resample training data, a random seed is also specified for the h2o GBM using the seed parameter to create reproducible predictions, error rates, and variable importance values. To avoid overfitting, the stopping_rounds parameter stops the training process after the test error fails to decrease for 5 iterations.
The balance_classes parameter ensures the positive and negative classes of the target variable are seen by the model in equal proportions during training. This can be very important for the LOCO calculations in sections 5 and 6 for unbalanced data. In experiments across several datasets, LOCO explanations for rows with a majority class label for the target variable (e.g., 0) are more likely to match those generated by another popular explanatory technique, LIME, when the target classes are rebalanced during training. balance_classes is commented out below because the row explained in this notebook has a minority class label (e.g., 1).
# initialize GBM model
model = H2OGradientBoostingEstimator(ntrees=150, # maximum 150 trees in GBM
max_depth=4, # trees can have maximum depth of 4
sample_rate=0.9, # use 90% of rows in each iteration (tree)
col_sample_rate=0.9, # use 90% of variables in each iteration (tree)
#balance_classes=True, # sample to balance 0/1 distribution of target - can help LOCO
stopping_rounds=5, # stop if validation error does not decrease for 5 iterations (trees)
score_tree_interval=1, # for reproducibility, set higher for bigger data
seed=12345) # for reproducibility
# train a GBM model
model.train(y=y, x=X, training_frame=train, validation_frame=test)
# print AUC
print('GBM Test AUC = %.2f' % model.auc(valid=True))
# uncomment to see model details
# print(model)
gbm Model Build progress: |███████████████████████████████████████████████| 100% GBM Test AUC = 0.78
During training, the h2o GBM aggregates the improvement in error caused by each split in each decision tree across all the decision trees in the ensemble classifier. These values are attributed to the input variable used in each split and give an indication of the contribution each input variable makes toward the model's predictions. The variable importance ranking should be consistent with human domain knowledge and reasonable expectations. In this case, a customer's most recent payment behavior, PAY_0, is by far the most important variable, followed by their second most recent payment behavior, PAY_2, and third most recent payment behavior, PAY_3. This result is well-aligned with business practices in credit lending: people who miss their most recent payments are likely to default soon.
model.varimp_plot()
A surrogate model is a simple model that is used to explain a complex model. One of the original references for surrogate models is available here: https://papers.nips.cc/paper/1152-extracting-tree-structured-representations-of-trained-networks.pdf. In this example, a single decision tree will be trained on the original inputs and predictions of the h2o GBM model, and the tree will be visualized using special functionality in h2o and GraphViz. The variable importance, interactions, and decision paths displayed in the directed graph of the trained decision tree surrogate model are then assumed to be indicative of the internal mechanisms of the more complex GBM model, creating an approximate, overall flowchart for the GBM. There are few mathematical guarantees that the simple surrogate model is highly representative of the more complex GBM, but a recent preprint article has put forward ideas on strengthening the theoretical relationship between surrogate models and more complex models: https://arxiv.org/pdf/1705.08504.pdf. Since surrogate models alone do not guarantee accurate transparency, they will be used along with GBM variable importance and LOCO to build a cohesive narrative about the mechanisms within the GBM. Because many currently-available explanatory techniques are approximate, it is recommended that users employ several different explanatory techniques and trust only consistent results across techniques. Also, as of h2o 3.24, Shapley values are supported for h2o GBM. Use them instead of LOCO for any high-stakes application.
To train a surrogate model, the predictions and original inputs of the complex model to be explained need to be in the same dataset. The test data is used here to see how the model behaves on holdout data; a surrogate trained on test inputs and predictions should reflect the GBM's behavior on new data more closely than one trained on the training inputs and predictions.
# cbind predictions to test frame
# give them a nice name
yhat = 'p_DEFAULT_NEXT_MONTH'
preds1 = test['ID'].cbind(model.predict(test).drop(['predict', 'p0']))
preds1.columns = ['ID', yhat]
test_yhat = test.cbind(preds1[yhat])
gbm prediction progress: |████████████████████████████████████████████████| 100%
A single decision tree is trained on the test inputs and predictions. To simulate a single decision tree in h2o, the H2ORandomForestEstimator class is used, but only one tree is trained instead of a forest of decision trees. Setting the mtries parameter to -2 tells the H2ORandomForestEstimator to consider all variables in all splits of a tree, instead of considering a random subset of columns. It is also recommended to set a random seed for reproducibility and to set max_depth to a lower number, say less than 6, so that the surrogate model does not become overly complex and hard to explain and understand. Once the tree is trained, a MOJO (model object, optimized) representation of the tree is saved. H2o provides a way to visualize the trained tree in detail using the MOJO and Graphviz.
model_id = 'dt_surrogate_mojo' # gives MOJO artifact a recognizable name
# initialize single tree surrogate model
surrogate = H2ORandomForestEstimator(ntrees=1, # use only one tree
sample_rate=1, # use all rows in that tree
mtries=-2, # use all columns in that tree
max_depth=3, # shallow trees are easier to understand
seed=12345, # random seed for reproducibility
model_id=model_id) # gives MOJO artifact a recognizable name
# train single tree surrogate model
surrogate.train(x=X, y=yhat, training_frame=test_yhat)
# persist MOJO (compiled representation of the trained model)
# from which to generate a plot of the surrogate
mojo_path = surrogate.download_mojo(path='.')
print('Generated MOJO path:\n', mojo_path)
drf Model Build progress: |███████████████████████████████████████████████| 100% Generated MOJO path: /home/patrickh/workspace/interpretable_machine_learning_with_python/dt_surrogate_mojo.zip
GraphViz is an open source graph visualization tool. It is freely available from this url: http://www.graphviz.org/. To plot the trained decision tree surrogate model, a special h2o class, PrintMojo, is executed against the MOJO to create a GraphViz dot file representation of the tree.
# title for plot
title = 'Credit Card Default Decision Tree Surrogate'
# locate h2o jar
hs = H2OLocalServer()
h2o_jar_path = hs._find_jar()
print('Discovered H2O jar path:\n', h2o_jar_path)
# construct command line call to generate graphviz version of
# the surrogate tree; for more information see:
# http://docs.h2o.ai/h2o/latest-stable/h2o-genmodel/javadoc/index.html
gv_file_name = model_id + '.gv'
gv_args = str('-cp ' + h2o_jar_path +
' hex.genmodel.tools.PrintMojo --tree 0 -i '
+ mojo_path + ' -o').split()
gv_args.insert(0, 'java')
gv_args.append(gv_file_name)
if title is not None:
gv_args = gv_args + ['--title', title]
# call
print()
print('Calling external process ...')
print(' '.join(gv_args))
# if the line below is failing for you, try instead:
# _ = subprocess.call(gv_args, shell=True)
_ = subprocess.call(gv_args)
Discovered H2O jar path: /home/patrickh/workspace/interpretable_machine_learning_with_python/env_iml/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar Calling external process ... java -cp /home/patrickh/workspace/interpretable_machine_learning_with_python/env_iml/lib/python3.6/site-packages/h2o/backend/bin/h2o.jar hex.genmodel.tools.PrintMojo --tree 0 -i /home/patrickh/workspace/interpretable_machine_learning_with_python/dt_surrogate_mojo.zip -o dt_surrogate_mojo.gv --title Credit Card Default Decision Tree Surrogate
Then a GraphViz command line tool is used to create a static PNG image from the dot file ...
# construct call to generate PNG from
# graphviz representation of the tree
png_file_name = model_id + '.png'
png_args = str('dot -Tpng ' + gv_file_name + ' -o ' + png_file_name)
png_args = png_args.split()
# call
print('Calling external process ...')
print(' '.join(png_args))
# if the line below is failing for you, try instead:
# _ = subprocess.call(png_args, shell=True)
_ = subprocess.call(png_args)
Calling external process ... dot -Tpng dt_surrogate_mojo.gv -o dt_surrogate_mojo.png
... and the image is displayed in the notebook.
# display in-notebook
display(Image((png_file_name)))
The displayed tree is comparable with the global GBM variable importance. A simple heuristic rule for variable importance in a decision tree relates to the depth and frequency at which a variable is split on in a tree: variables used higher in the tree and more frequently in the tree are more important. Most of the variables pictured in this tree also appear as highly important in the GBM variable importance plot. In both cases, PAY_0 appears crucially important, with other payment behavior variables following close behind. The surrogate decision tree enables users to understand and confirm not only which input variables are important, but also how their values contribute to model decisions. For instance, to fall into the lowest-probability-of-default leaf node in the surrogate decision tree, a customer must make their first and second payments in a timely fashion and then pay more than 1515.5 New Taiwan Dollars for their fifth payment. Conversely, customers who miss their first, fifth, and third payments fall into the highest-probability-of-default leaf node of the surrogate decision tree. It is also imperative to compare these results to domain knowledge and reasonable expectations. In this case, the global explanatory methods applied thus far tell a consistent and reasonable story about the GBM's behavior. If this were not so, steps should be taken to either reconcile or remove inconsistencies and unreasonable prediction behavior.
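The depth-and-frequency heuristic above can be partially automated by counting how often each variable appears as a split in the GraphViz dot file generated earlier. Below is a minimal sketch; the dot-file fragment, node names, and labels are invented for illustration, not copied from the real dt_surrogate_mojo.gv.

```python
import re

# hypothetical fragment of a MOJO-derived GraphViz file (the real file
# generated above is dt_surrogate_mojo.gv); numeric labels are leaves
gv = '''
"SG_0_Node_0" [label="PAY_0"]
"SG_0_Node_1" [label="PAY_5"]
"SG_0_Node_2" [label="PAY_0"]
"SG_0_Node_3" [fontsize=14, label="0.21"]
'''

# frequency heuristic: variables appearing in more split nodes are
# treated as more important in the surrogate tree; numeric leaf labels
# do not match the variable-name pattern and are skipped
splits = re.findall(r'label="([A-Z_0-9]+)"', gv)
freq = {name: splits.count(name) for name in set(splits)}
print(freq)  # PAY_0 appears in two splits, PAY_5 in one
```

This captures only the frequency half of the heuristic; depth would require parsing the edge structure of the dot file as well.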
Now that a solid understanding of global model behavior has been attained, local behavior for any given row of data and prediction can be analyzed and validated using LOCO. The LOCO method presented here is adapted from Distribution-Free Predictive Inference for Regression by Jing Lei et al., http://www.stat.cmu.edu/~ryantibs/papers/conformal.pdf. Here the local contribution of an input variable to a prediction for a single row of data is estimated by rescoring the GBM on that row one time for each input variable, each time leaving out one input variable (e.g., "covariate") by setting it to missing, and then subtracting the new score from the original score. By default, h2o scores missing data in decision trees by running them through the majority decision path. This means LOCO will be a numeric measure of how different the local contribution of an input variable is from the most common local contribution of that variable in the model. This variant of LOCO differs from the original method, in which one input variable is dropped from the model and the model is retrained without that variable. For nonlinear models, nonlinear dependencies can allow variables to nearly completely replace one another when a variable is dropped and the model is retrained. Hence, the approach of injecting missing values is used to estimate local contributions of input variables for nonlinear models here, as opposed to dropping a variable and retraining the model.
To implement LOCO, GBM model predictions are calculated once for the test data and then again for each input variable, setting the entire input variable column to missing. Once the prediction without the variable is found for every row of data in the test set, that column vector of predictions on corrupted data can be subtracted from the column vector of predictions on the original, non-corrupted data to estimate the local contribution of that variable for each prediction in the test data. For better local accuracy and explainability, LOCO contributions are scaled such that the contributions for each prediction, plus the overall average of DEFAULT_NEXT_MONTH, always sum to the model prediction.
h2o.no_progress() # turn off h2o gratuitous progress bars
# create set of original predictions and row ID
preds2 = test['ID'].cbind(model.predict(test).drop(['predict', 'p0']))
preds2.columns = ['ID', yhat]
# calculate LOCO for each variable
print('Calculating LOCO contributions ...')
for k, i in enumerate(X):
# train and predict with x_i set to missing
test_loco = h2o.deep_copy(test, 'test_loco')
test_loco[i] = np.nan
preds_loco = model.predict(test_loco).drop(['predict','p0'])
# create a new, named column for the LOCO prediction
preds_loco.columns = [i]
preds2 = preds2.cbind(preds_loco)
# subtract the LOCO prediction from the original prediction
preds2[i] = preds2[yhat] - preds2[i]
# update progress
print('LOCO Progress: ' + i + ' (' + str(k+1) + '/' + str(len(X)) + ') ...')
# scale contributions to sum to yhat - y_0
print('\nScaling contributions ...')
y_0 = test[y].mean()[0]
preds2_pd = preds2.as_data_frame()
pred_ = preds2_pd[yhat]
scaler = (pred_ - y_0) / preds2_pd[X].sum(axis=1)
preds2_pd[X] = preds2_pd[X].multiply(scaler, axis=0)
print('Done.')
preds2_pd.head()
Calculating LOCO contributions ... LOCO Progress: LIMIT_BAL (1/23) ... LOCO Progress: SEX (2/23) ... LOCO Progress: EDUCATION (3/23) ... LOCO Progress: MARRIAGE (4/23) ... LOCO Progress: AGE (5/23) ... LOCO Progress: PAY_0 (6/23) ... LOCO Progress: PAY_2 (7/23) ... LOCO Progress: PAY_3 (8/23) ... LOCO Progress: PAY_4 (9/23) ... LOCO Progress: PAY_5 (10/23) ... LOCO Progress: PAY_6 (11/23) ... LOCO Progress: BILL_AMT1 (12/23) ... LOCO Progress: BILL_AMT2 (13/23) ... LOCO Progress: BILL_AMT3 (14/23) ... LOCO Progress: BILL_AMT4 (15/23) ... LOCO Progress: BILL_AMT5 (16/23) ... LOCO Progress: BILL_AMT6 (17/23) ... LOCO Progress: PAY_AMT1 (18/23) ... LOCO Progress: PAY_AMT2 (19/23) ... LOCO Progress: PAY_AMT3 (20/23) ... LOCO Progress: PAY_AMT4 (21/23) ... LOCO Progress: PAY_AMT5 (22/23) ... LOCO Progress: PAY_AMT6 (23/23) ... Scaling contributions ... Done.
ID | p_DEFAULT_NEXT_MONTH | LIMIT_BAL | SEX | EDUCATION | MARRIAGE | AGE | PAY_0 | PAY_2 | PAY_3 | ... | BILL_AMT3 | BILL_AMT4 | BILL_AMT5 | BILL_AMT6 | PAY_AMT1 | PAY_AMT2 | PAY_AMT3 | PAY_AMT4 | PAY_AMT5 | PAY_AMT6 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 4 | 0.144991 | -0.079758 | -0.000000 | -0.00000 | -0.009143 | -0.001728 | -0.000000 | -0.000000 | -0.000000 | ... | -0.002113 | 0.005768 | -0.000000 | -0.000000 | -0.000000 | -0.000000 | -0.000000 | -0.000000 | -0.000000 | -0.005340 |
1 | 8 | 0.128193 | -0.020007 | -0.000000 | -0.00000 | -0.000000 | -0.000000 | 0.011403 | -0.000000 | 0.045036 | ... | 0.010062 | -0.000000 | -0.059467 | 0.015406 | -0.028304 | -0.036129 | -0.067713 | -0.000000 | 0.057314 | -0.000000 |
2 | 10 | 0.179911 | -0.024094 | -0.002945 | -0.00000 | -0.000000 | -0.009850 | 0.003778 | -0.000000 | 0.005310 | ... | -0.000000 | -0.000000 | -0.000000 | -0.000000 | -0.005015 | -0.003077 | -0.009657 | 0.019491 | -0.000000 | -0.004973 |
3 | 16 | 0.325205 | 0.012617 | 0.000000 | 0.00000 | 0.004043 | 0.000000 | 0.027090 | 0.055267 | 0.000000 | ... | 0.000000 | -0.003265 | 0.000000 | 0.000000 | 0.005403 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
4 | 17 | 0.408821 | 0.033188 | 0.000000 | 0.00186 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.051523 | ... | -0.001966 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.001247 | 0.000000 | 0.008397 |
5 rows × 25 columns
The numeric LOCO values in each column are an estimate of how much each variable contributed to each prediction. LOCO can indicate how a variable and its values were weighted in any given decision by the model. These values are crucially important for machine learning interpretability and are related to "local feature importance", "reason codes", or "turn-down codes." The latter phrases are borrowed from credit scoring. Credit lenders in the U.S. must provide reasons for automatically rejecting a credit application. Reason codes can be easily extracted from LOCO local variable contribution values by simply ranking the variables that played the largest role in any given decision.
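As a concrete illustration of that ranking step, the sketch below sorts one row of LOCO contributions and takes the top variables as reason codes. The contribution values here are synthetic placeholders, not values from the preds2_pd frame, and reason_codes is a hypothetical helper name.

```python
import pandas as pd

# hypothetical LOCO contributions for a single prediction
# (synthetic values, for illustration only)
row = pd.Series({'PAY_0': 0.21, 'PAY_2': 0.07, 'LIMIT_BAL': -0.05,
                 'AGE': 0.01, 'BILL_AMT1': -0.02})

def reason_codes(contribs, n=3):
    """Ranks LOCO contributions, largest positive first, so the top n
    variables are those that pushed the prediction most toward default."""
    return list(contribs.sort_values(ascending=False).index[:n])

print(reason_codes(row))  # ['PAY_0', 'PAY_2', 'AGE']
```

Negative contributions could be ranked the same way to explain what pushed a prediction away from default.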
The function below finds and returns the row indices for the minimum, the maximum, and the deciles of one column in terms of another — in this case, the model predictions (`p_DEFAULT_NEXT_MONTH`) in terms of the row identifier (`ID`). These indices are used as a starting point for finding potentially interesting predictions. Outlying predictions found through residual analysis are another group of potentially interesting local predictions to analyze with LOCO.
def get_percentile_dict(yhat, id_, frame):

    """ Returns the minimum, maximum, and percentiles of a column, yhat,
        as the values of another column, id_.

    Args:
        yhat: Name of the column in which to find percentiles.
        id_: Name of the ID column that stores indices for percentiles of yhat.
        frame: Pandas DataFrame containing yhat and id_.

    Returns:
        Dictionary of percentile values and index column values.

    """

    # copy and sort by prediction
    sort_df = frame.copy(deep=True)
    sort_df.sort_values(yhat, inplace=True)
    sort_df.reset_index(inplace=True)

    # find minimum and maximum
    percentiles_dict = {}
    percentiles_dict[0] = sort_df.loc[0, id_]
    percentiles_dict[99] = sort_df.loc[sort_df.shape[0] - 1, id_]

    # find 10th-90th percentiles
    inc = sort_df.shape[0] // 10
    for i in range(1, 10):
        percentiles_dict[i * 10] = sort_df.loc[i * inc, id_]

    return percentiles_dict
# display percentiles dictionary
# ID values for rows
# from lowest prediction
# to highest prediction
percentile_dict = get_percentile_dict(yhat, 'ID', preds2_pd)
percentile_dict
{0: 28716, 99: 29116, 10: 8942, 20: 28257, 30: 4074, 40: 13411, 50: 16633, 60: 2402, 70: 19769, 80: 25069, 90: 21372}
Investigating customers with very high or low predicted probabilities, to determine whether their local explanations justify their extreme predictions, is typically a productive exercise in boundary testing, model debugging, and validation. Reason codes are generated below for the customer with the highest predicted probability of default in the test set, but LOCO can create local explanations for any or all rows in the training or test datasets, and on new data.
# select single customer, convert to Pandas,
# and drop prediction and row ID
risky_loco = preds2_pd[preds2_pd['ID'] == int(percentile_dict[99])].drop(['ID', yhat], axis=1)
# transpose into a column vector; after transposing, the single
# column is named by the customer's original row index
risky_loco = risky_loco.T
risky_loco = risky_loco.sort_values(by=risky_loco.columns[0], ascending=False)[:5]
# plot top five local contributions
_ = risky_loco.plot(kind='bar',
                    title='Top Five Reason Codes for a Risky Customer\n',
                    legend=False)
For the customer in the test dataset that the GBM predicts as most likely to default, the most important input variables in the prediction are, in descending order, `PAY_0`, `PAY_6`, `PAY_3`, `PAY_5`, and `AGE`.
The local contributions for this customer appear reasonable, especially considering her payment information. Her most recent payment was 3 months late and her payment from 6 months prior was 4 months late, so it's logical that these variables would weigh heavily in the model's prediction of default for this customer.
test_yhat[test_yhat['ID'] == int(percentile_dict[99]), :] # helps understand reason codes
ID | LIMIT_BAL | SEX | EDUCATION | MARRIAGE | AGE | PAY_0 | PAY_2 | PAY_3 | PAY_4 | PAY_5 | PAY_6 | BILL_AMT1 | BILL_AMT2 | BILL_AMT3 | BILL_AMT4 | BILL_AMT5 | BILL_AMT6 | PAY_AMT1 | PAY_AMT2 | PAY_AMT3 | PAY_AMT4 | PAY_AMT5 | PAY_AMT6 | DEFAULT_NEXT_MONTH | p_DEFAULT_NEXT_MONTH |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
29116 | 20000 | female | university | married | 59 | 3 month delay | 2 month delay | 3 month delay | 2 month delay | 2 month delay | 4 month delay | 8803 | 11137 | 10672 | 11201 | 12721 | 11946 | 2800 | 0 | 1000 | 2000 | 0 | 0 | 1 | 0.895285 |
To generate reason codes for the model's decision, each locally important variable and its value are used together. If this customer were denied future credit based on this model and data, the top five LOCO-based reason codes for the automated decision would be:

* `PAY_0`: most recent repayment was 3 months delayed
* `PAY_6`: repayment 6 months ago was 4 months delayed
* `PAY_3`: repayment 3 months ago was 3 months delayed
* `PAY_5`: repayment 5 months ago was 2 months delayed
* `AGE`: the customer's age is 59
(Of course, in many places, variables like `AGE` and `SEX` cannot, and should not, be used in credit lending or other high-stakes decisions. For a slightly more careful treatment of GBM in a fair lending context, see: https://github.com/jphall663/interpretable_machine_learning_with_python/blob/master/dia.ipynb)
Just like predictions from high-variance, nonlinear models, explanations derived from machine learning models can be unstable. One general way to decrease variance is to ensemble the results of many models. The last section of this notebook puts forward a simple approach to creating ensemble explanations.
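The variance-reduction intuition can be checked with a quick toy simulation, using synthetic numbers unrelated to the credit card data: the mean of several noisy per-model explanations fluctuates less than any single one.

```python
import numpy as np

# simulate a LOCO value for one variable, as estimated by 10 slightly
# different models, over 1000 hypothetical rows (all numbers synthetic)
rng = np.random.default_rng(0)
single_model_locos = 0.14 + rng.normal(0.0, 0.03, size=(1000, 10))

# the 10-model mean explanation varies much less than a single model's
mean_std = single_model_locos.mean(axis=1).std()
single_std = single_model_locos[:, 0].std()
print(mean_std < single_std)  # True
```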
To create ensemble explanations, several accurate models are trained. The models and their predictions on the test data are stored in Python lists.
n_models = 10  # select number of models

# lists for holding models and predictions
models = []
pred_frames = []

for i in range(0, n_models):

    # initialize and store models
    models.append(H2OGradientBoostingEstimator(
        ntrees=150,
        max_depth=4,
        sample_rate=0.9 - ((i + 1) * 0.01),      # perturb row sample rate
        col_sample_rate=0.9 - ((i + 1) * 0.01),  # perturb column sample rate
        # balance_classes=True,                  # sample to balance 0/1 distribution of target - helps LOCO
        stopping_rounds=5,                       # stop if validation error does not decrease for 5 iterations (trees)
        seed=i + 1))                             # new random seed for each model

    # train models
    models[i].train(y=y, x=X, training_frame=train, validation_frame=test)

    # store predictions
    pred_frames.append(test['ID'].cbind(models[i].predict(test).drop(['predict', 'p0'])))
    pred_frames[i].columns = ['ID', yhat]

    # update progress
    print('Training Progress: model %d/%d, AUC = %.4f ...' % (i + 1, n_models, models[i].auc(valid=True)))

print('Done.')
Training Progress: model 1/10, AUC = 0.7813 ... Training Progress: model 2/10, AUC = 0.7803 ... Training Progress: model 3/10, AUC = 0.7787 ... Training Progress: model 4/10, AUC = 0.7826 ... Training Progress: model 5/10, AUC = 0.7804 ... Training Progress: model 6/10, AUC = 0.7800 ... Training Progress: model 7/10, AUC = 0.7802 ... Training Progress: model 8/10, AUC = 0.7799 ... Training Progress: model 9/10, AUC = 0.7796 ... Training Progress: model 10/10, AUC = 0.7811 ... Done.
LOCO is calculated on the test data for each model, each input, and each row of data in the test set using the stored models and predictions.
# for each new model ...
for k, model in enumerate(models):

    # calculate LOCO for each input variable
    for i in X:

        # predict with Xi set to missing; the model is not retrained
        test_loco = h2o.deep_copy(test, 'test_loco')
        test_loco[i] = np.nan
        preds_loco = model.predict(test_loco).drop(['predict', 'p0'])

        # create a new, named column for the LOCO prediction
        preds_loco.columns = [i]
        pred_frames[k] = pred_frames[k].cbind(preds_loco)

        # subtract the LOCO prediction from the original prediction
        pred_frames[k][i] = pred_frames[k][yhat] - pred_frames[k][i]

    # update progress
    print('LOCO Progress: model %d/%d ...' % (k + 1, n_models))

print('Done.')
LOCO Progress: model 1/10 ... LOCO Progress: model 2/10 ... LOCO Progress: model 3/10 ... LOCO Progress: model 4/10 ... LOCO Progress: model 5/10 ... LOCO Progress: model 6/10 ... LOCO Progress: model 7/10 ... LOCO Progress: model 8/10 ... LOCO Progress: model 9/10 ... LOCO Progress: model 10/10 ... Done.
To create ensemble explanations for a single row, the LOCO values for each variable in that row are averaged across all models. Single-model and mean LOCO values for the riskiest customer in the test set are displayed below. Notice that even slight changes in model specifications can result in different explanations. For example, the local contribution of `PAY_0` for the riskiest customer ranges from roughly 0.09 to 0.20 across the 10 models in the table below.
# holds LOCO values for a specific row
risky_loco_frames = []

# column names for the Pandas DataFrame of combined LOCO values
col_names = ['Loco ' + str(i) for i in range(1, n_models + 1)]

# for each new model ...
for i in range(0, n_models):

    # collect LOCO for that model and a specific row
    # as a column vector in a Pandas DataFrame
    preds = pred_frames[i]
    risky_loco_frames.append(preds[preds['ID'] == int(percentile_dict[99]), :]  # row for risky customer
                             .as_data_frame()                                   # convert to Pandas
                             .drop(['ID', yhat], axis=1)                        # drop prediction and row ID
                             .T)                                                # transpose into column vector

# bind the LOCO column vectors from all models
# into the same Pandas DataFrame
loco_ensemble = pd.concat(risky_loco_frames, axis=1)

# update column names
loco_ensemble.columns = col_names

# mean local importance across models
loco_ensemble['Mean Local Importance'] = loco_ensemble.mean(axis=1)

# scale contributions so they sum to the difference between
# this customer's prediction and the baseline prediction, y_0
scaler = (test_yhat[test_yhat['ID'] == int(percentile_dict[99]), yhat] - y_0) /\
         (loco_ensemble['Mean Local Importance'].sum())
loco_ensemble['Scaled Mean Local Importance'] = loco_ensemble['Mean Local Importance'] * scaler[0, 0]

# standard deviation of local importance across models
loco_ensemble['Std. Dev. Local Importance'] = loco_ensemble\
    .drop('Scaled Mean Local Importance', axis=1)\
    .std(axis=1)

# display
loco_ensemble
Variable | Loco 1 | Loco 2 | Loco 3 | Loco 4 | Loco 5 | Loco 6 | Loco 7 | Loco 8 | Loco 9 | Loco 10 | Mean Local Importance | Scaled Mean Local Importance | Std. Dev. Local Importance
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
LIMIT_BAL | 0.013205 | 0.011040 | 0.012483 | 0.001345 | 0.015150 | -0.006688 | -0.002158 | -0.009428 | -0.002005 | -0.005503 | 0.002744 | 0.004116 | 0.008836 |
SEX | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
EDUCATION | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 |
MARRIAGE | 0.051186 | 0.023304 | 0.049574 | 0.044072 | 0.034987 | 0.083513 | 0.053444 | 0.034809 | 0.051611 | 0.097880 | 0.052438 | 0.078652 | 0.021392 |
AGE | 0.022646 | 0.084028 | 0.015365 | 0.019450 | 0.047277 | 0.008411 | 0.019311 | 0.019194 | 0.064969 | 0.047981 | 0.034863 | 0.052292 | 0.023673 |
PAY_0 | 0.199285 | 0.136971 | 0.155368 | 0.123913 | 0.103977 | 0.094563 | 0.129808 | 0.166523 | 0.163303 | 0.118950 | 0.139266 | 0.208887 | 0.030281 |
PAY_2 | 0.002721 | 0.023228 | 0.029913 | 0.064136 | 0.047839 | 0.003387 | 0.044296 | 0.002752 | 0.021887 | 0.028216 | 0.026837 | 0.040254 | 0.019742 |
PAY_3 | 0.068494 | 0.093103 | 0.047630 | 0.046408 | 0.044800 | 0.016914 | 0.038656 | 0.046357 | 0.088166 | 0.054055 | 0.054458 | 0.081683 | 0.021809 |
PAY_4 | 0.030388 | 0.053298 | 0.039892 | -0.009098 | 0.021445 | 0.022628 | 0.043056 | 0.038197 | 0.063759 | 0.021841 | 0.032541 | 0.048808 | 0.019174 |
PAY_5 | 0.064508 | 0.051064 | 0.033604 | 0.069855 | 0.024448 | 0.038036 | 0.031212 | 0.020676 | 0.063064 | 0.072998 | 0.046947 | 0.070416 | 0.018681 |
PAY_6 | 0.030094 | 0.033554 | 0.027220 | -0.023089 | 0.005430 | 0.005204 | 0.022503 | 0.029180 | 0.028095 | 0.030959 | 0.018915 | 0.028371 | 0.017000 |
BILL_AMT1 | 0.000933 | 0.036190 | 0.002133 | 0.025375 | 0.001071 | 0.004628 | 0.018029 | 0.004658 | 0.021524 | 0.027100 | 0.014164 | 0.021245 | 0.012328 |
BILL_AMT2 | 0.008804 | 0.001631 | 0.000000 | 0.000000 | 0.031767 | 0.000000 | 0.008615 | 0.000000 | 0.000000 | 0.000000 | 0.005082 | 0.007622 | 0.009515 |
BILL_AMT3 | 0.000000 | -0.001008 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | -0.001648 | 0.000000 | -0.000266 | -0.000398 | 0.000550 |
BILL_AMT4 | -0.001370 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | -0.010070 | 0.000000 | -0.005586 | 0.000000 | -0.001738 | -0.001876 | -0.002814 | 0.003198 |
BILL_AMT5 | 0.000000 | 0.000000 | 0.000000 | 0.003339 | 0.000000 | 0.000000 | 0.000000 | 0.004020 | -0.010965 | 0.009204 | 0.000560 | 0.000840 | 0.004787 |
BILL_AMT6 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.006478 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000648 | 0.000972 | 0.001943 |
PAY_AMT1 | 0.000000 | 0.004307 | 0.000000 | 0.000000 | 0.000000 | 0.003919 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000823 | 0.001234 | 0.001648 |
PAY_AMT2 | 0.000000 | 0.001541 | 0.000000 | -0.001350 | 0.004503 | 0.002450 | 0.000000 | 0.001922 | 0.000000 | 0.000985 | 0.001005 | 0.001508 | 0.001582 |
PAY_AMT3 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.005040 | 0.000504 | 0.000756 | 0.001512 |
PAY_AMT4 | 0.000000 | 0.000000 | 0.000000 | -0.006528 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | -0.000653 | -0.000979 | 0.001959 |
PAY_AMT5 | 0.000000 | 0.014605 | 0.016906 | 0.016336 | 0.001440 | 0.004975 | 0.029793 | 0.045810 | 0.019439 | 0.047491 | 0.019679 | 0.029517 | 0.015936 |
PAY_AMT6 | -0.000925 | 0.000000 | 0.000000 | -0.009342 | -0.000800 | -0.000556 | 0.000000 | 0.000000 | 0.000000 | 0.000000 | -0.001162 | -0.001743 | 0.002749 |
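The scaling step above (multiplying mean contributions by `scaler`) rescales the mean LOCO values so they sum to the difference between the customer's prediction and the baseline prediction, `y_0`. A minimal sketch of that arithmetic, with invented numbers:

```python
import numpy as np

# invented mean LOCO contributions, row prediction, and baseline prediction
contribs = np.array([0.14, 0.05, 0.03, -0.02])
prediction, y_0 = 0.90, 0.22

# rescale so the contributions account exactly for the prediction's
# distance from the baseline: scaled sums to prediction - y_0
scaled = contribs * ((prediction - y_0) / contribs.sum())
print(round(scaled.sum(), 6))  # 0.68
```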
Taking mean explanations across multiple models leads to reason codes somewhat different from those produced by any single model. Mean reason codes may be more stable, they represent explanations from several models at once, and they may take practitioners a step closer to using machine learning models to make inferential conclusions about phenomena represented in the training or test data, instead of simply providing an approximate explanation of a single model's decision processes.
risky_mean_loco = loco_ensemble['Mean Local Importance'].sort_values(ascending=False)[:5]
_ = risky_mean_loco.plot(kind='bar',
title='Top Five Reason Codes for a Risky Customer\n',
color='b',
legend=False)
After using h2o, it's typically best to shut it down. However, before doing so, users should ensure that they have saved any h2o data structures, such as models and H2OFrames, or scoring artifacts, such as POJOs and MOJOs.
# be careful, this can erase your work!
h2o.cluster().shutdown(prompt=True)
Are you sure you want to shutdown the H2O instance running at http://127.0.0.1:54321 (Y/N)? y H2O session _sid_ae25 closed.
In this notebook, a complex GBM classifier was trained to predict credit card defaults, then explained at a global scale with a decision tree surrogate model and at a local scale with LOCO. An ensemble LOCO approach was also introduced to stabilize approximate explanations. The decision tree surrogate creates an approximate overall flowchart for the GBM's decision processes, and LOCO can be used to create reason codes for each model prediction. All of these techniques enhance the transparency of the complex model, which in turn enables greater accountability for its predictions. These techniques should generalize well to many types of business and research problems, enabling you to train a complex GBM model and explain it to your colleagues, bosses, and, potentially, external regulators.