Set up a regression experiment

In [ ]:
import pandas as pd
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split

# Load the Boston housing data and put it in a DataFrame with one column per
# feature plus a trailing "target" column.
# NOTE(review): load_boston was deprecated in scikit-learn 1.0 and removed in
# 1.2, so this cell needs scikit-learn < 1.2 or a replacement dataset.
boston = load_boston()
feature_names = list(boston.feature_names)
df = pd.DataFrame(boston.data, columns=feature_names)
df["target"] = boston.target
# Optional: uncomment to work on a 10% sample for a quicker run.
# df = df.sample(frac=0.1, random_state=1)

# "target" was appended last, so every column before it is a feature.
train_cols = df.columns[0:-1]
label = df.columns[-1]
X = df[train_cols]
y = df[label]

# Fixed seed so the 80/20 train/test split is reproducible.
seed = 1
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=seed)

Explore the dataset

In [ ]:
from interpret import show
from interpret.data import Marginal

# Build a data explanation of the training split and render it.
marginal = Marginal().explain_data(X_train, y_train, name = 'Train Data')
show(marginal)

Train the Explainable Boosting Machine (EBM)

In [ ]:
from interpret.glassbox import ExplainableBoostingRegressor, LinearRegression, RegressionTree

# Construct the regressor first, then fit it on the training split.
ebm = ExplainableBoostingRegressor(random_state=seed)
ebm.fit(X_train, y_train)   #Works on dataframes and numpy arrays

Global Explanations: What the model learned overall

In [ ]:
# Model-level explanation of the fitted EBM, labelled 'EBM'.
ebm_global = ebm.explain_global(
    name='EBM',
)

Local Explanations: How an individual prediction was made

In [ ]:
# Explain the first five held-out rows individually (positional slice).
# This explanation is not part of the final dashboard list, so render it here.
ebm_local = ebm.explain_local(X_test.iloc[:5], y_test.iloc[:5], name='EBM')
show(ebm_local)

Evaluate EBM performance

In [ ]:
from interpret import show
from interpret.perf import RegressionPerf

# Score the EBM's predictions on the held-out split and render the result.
ebm_perf = RegressionPerf(ebm.predict).explain_perf(X_test, y_test, name='EBM')
show(ebm_perf)

Let's test out a few other Explainable Models

In [ ]:
from interpret.glassbox import LinearRegression, RegressionTree

# Fit two more glassbox baselines on the same training split and seed.
lr = LinearRegression(random_state=seed)
lr.fit(X_train, y_train)

rt = RegressionTree(random_state=seed)
rt.fit(X_train, y_train)

Compare performance using the Dashboard

In [ ]:
# Score both baselines on the held-out split.
lr_perf = RegressionPerf(lr.predict).explain_perf(
    X_test, y_test, name='Linear Regression'
)
rt_perf = RegressionPerf(rt.predict).explain_perf(
    X_test, y_test, name='Regression Tree'
)


Glassbox: All of our models have global and local explanations

In [ ]:
# Model-level explanations for the two baselines.
lr_global = lr.explain_global(
    name='Linear Regression',
)
rt_global = rt.explain_global(
    name='Regression Tree',
)


Dashboard: look at everything at once

In [ ]:
# Passing a list to show() opens the InterpretML Dashboard with every
# explanation below available in one place.
show([
    marginal,
    lr_global, lr_perf,
    rt_global, rt_perf,
    ebm_global, ebm_perf,
])