# Libraries
library(h2o)     # H2O machine learning (GBM, AutoML, Stacked Ensembles)
library(lime)    # machine learning interpretation (LIME)
library(mlbench) # example datasets (PimaIndiansDiabetes)

# Fix the random seed so the train/test split and models are reproducible.
# Your lucky seed here ...
n_seed <- 12345

# Pima Indians Diabetes dataset: 768 rows, 8 predictors + 1 binary target
data("PimaIndiansDiabetes")
dim(PimaIndiansDiabetes)
## [1] 768 9
head(PimaIndiansDiabetes)
pregnant | glucose | pressure | triceps | insulin | mass | pedigree | age | diabetes |
---|---|---|---|---|---|---|---|---|
6 | 148 | 72 | 35 | 0 | 33.6 | 0.627 | 50 | pos |
1 | 85 | 66 | 29 | 0 | 26.6 | 0.351 | 31 | neg |
8 | 183 | 64 | 0 | 0 | 23.3 | 0.672 | 32 | pos |
1 | 89 | 66 | 23 | 94 | 28.1 | 0.167 | 21 | neg |
0 | 137 | 40 | 35 | 168 | 43.1 | 2.288 | 33 | pos |
5 | 116 | 74 | 0 | 0 | 25.6 | 0.201 | 30 | neg |
# Target: the binary outcome column (diabetes: pos/neg).
# (The original comment said "Median House Value" -- a leftover from a
# regression example; this is a classification problem.)
target <- "diabetes"
# Features: every other column in the dataset
features <- setdiff(colnames(PimaIndiansDiabetes), target)
print(features)
## [1] "pregnant" "glucose" "pressure" "triceps" "insulin" "mass"
## [7] "pedigree" "age"
# Start a local H2O cluster (JVM)
# By default this uses all available cores and auto-sized memory;
# see ?h2o.init for the nthreads / max_mem_size arguments.
h2o.init()
##
## H2O is not running yet, starting it now...
##
## Note: In case of errors look at the following log files:
## /var/folders/4z/p7yt7_4n4fj1jlyq6g4qhfbw0000gn/T//RtmpTwIbBb/h2o_jofaichow_started_from_r.out
## /var/folders/4z/p7yt7_4n4fj1jlyq6g4qhfbw0000gn/T//RtmpTwIbBb/h2o_jofaichow_started_from_r.err
##
##
## Starting H2O JVM and connecting: . Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 3 seconds 843 milliseconds
## H2O cluster timezone: Europe/London
## H2O data parsing timezone: UTC
## H2O cluster version: 3.20.0.2
## H2O cluster version age: 8 days
## H2O cluster name: H2O_started_from_R_jofaichow_tkc225
## H2O cluster total nodes: 1
## H2O cluster total memory: 3.56 GB
## H2O cluster total cores: 8
## H2O cluster allowed cores: 8
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## H2O Internal Security: FALSE
## H2O API Extensions: XGBoost, Algos, AutoML, Core V3, Core V4
## R Version: R version 3.5.0 (2018-04-23)
h2o.no_progress() # disable progress bar for RMarkdown
h2o.removeAll()   # Optional: remove anything from previous session
## [1] 0
# Copy the R data frame into the H2O cluster as an H2OFrame
h_diabetes <- as.h2o(PimaIndiansDiabetes)
# NOTE(review): this previews the original R data frame; presumably
# head(h_diabetes) was intended here -- the printed rows are identical.
head(PimaIndiansDiabetes)
pregnant | glucose | pressure | triceps | insulin | mass | pedigree | age | diabetes |
---|---|---|---|---|---|---|---|---|
6 | 148 | 72 | 35 | 0 | 33.6 | 0.627 | 50 | pos |
1 | 85 | 66 | 29 | 0 | 26.6 | 0.351 | 31 | neg |
8 | 183 | 64 | 0 | 0 | 23.3 | 0.672 | 32 | pos |
1 | 89 | 66 | 23 | 94 | 28.1 | 0.167 | 21 | neg |
0 | 137 | 40 | 35 | 168 | 43.1 | 2.288 | 33 | pos |
5 | 116 | 74 | 0 | 0 | 25.6 | 0.201 | 30 | neg |
# Split Train/Test (75% / 25%), reproducible via the seed
h_split <- h2o.splitFrame(h_diabetes, ratios = 0.75, seed = n_seed)
h_train <- h_split[[1]] # 75% for modelling
h_test  <- h_split[[2]] # 25% for evaluation

# Train a default H2O GBM model.
# The target column is a factor, so H2O fits a binomial classifier.
model_gbm <- h2o.gbm(x = features,
                     y = target,
                     training_frame = h_train,
                     model_id = "gbm_default_class",
                     seed = n_seed)
print(model_gbm)
## Model Details:
## ==============
##
## H2OBinomialModel: gbm
## Model ID: gbm_default_class
## Model Summary:
## number_of_trees number_of_internal_trees model_size_in_bytes min_depth
## 1 50 50 13586 5
## max_depth mean_depth min_leaves max_leaves mean_leaves
## 1 5 5.00000 8 25 16.62000
##
##
## H2OBinomialMetrics: gbm
## ** Reported on training data. **
##
## MSE: 0.05514127
## RMSE: 0.2348218
## LogLoss: 0.217466
## Mean Per-Class Error: 0.0403526
## AUC: 0.9918796
## Gini: 0.9837591
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## neg pos Error Rate
## neg 361 6 0.016349 =6/367
## pos 13 189 0.064356 =13/202
## Totals 374 195 0.033392 =19/569
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.483222 0.952141 165
## 2 max f2 0.358389 0.954677 194
## 3 max f0point5 0.483222 0.962322 165
## 4 max accuracy 0.483222 0.966608 165
## 5 max precision 0.972361 1.000000 0
## 6 max recall 0.231660 1.000000 235
## 7 max specificity 0.972361 1.000000 0
## 8 max absolute_mcc 0.483222 0.926852 165
## 9 max min_per_class_accuracy 0.463749 0.945545 170
## 10 max mean_per_class_accuracy 0.483222 0.959647 165
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
# Evaluate performance on the held-out test set
# (training metrics above are optimistic; this is the honest estimate)
h2o.performance(model_gbm, newdata = h_test)
## H2OBinomialMetrics: gbm
##
## MSE: 0.159823
## RMSE: 0.3997787
## LogLoss: 0.4948851
## Mean Per-Class Error: 0.2300068
## AUC: 0.8238779
## Gini: 0.6477558
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## neg pos Error Rate
## neg 96 37 0.278195 =37/133
## pos 12 54 0.181818 =12/66
## Totals 108 91 0.246231 =49/199
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.324056 0.687898 90
## 2 max f2 0.188454 0.800000 110
## 3 max f0point5 0.634708 0.691057 44
## 4 max accuracy 0.634708 0.783920 44
## 5 max precision 0.955667 1.000000 0
## 6 max recall 0.013187 1.000000 198
## 7 max specificity 0.955667 1.000000 0
## 8 max absolute_mcc 0.324056 0.510326 90
## 9 max min_per_class_accuracy 0.360682 0.744361 83
## 10 max mean_per_class_accuracy 0.219388 0.770221 102
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
# Train multiple H2O models with H2O AutoML.
# Stacked Ensembles will be created from those H2O models.
# You tell H2O ...
#   1) how much time you have (max_runtime_secs) and/or
#   2) how many models you want (max_models)
# Note: the H2O deep learning algo on multi-core is stochastic, so results
# are only approximately reproducible even with a fixed seed.
model_automl <- h2o.automl(x = features,
                           y = target,
                           training_frame = h_train,
                           nfolds = 5,              # k-fold cross-validation
                           max_runtime_secs = 120,  # time budget
                           max_models = 100,        # max no. of models
                           stopping_metric = "AUC", # metric to optimize
                           project_name = "automl_class",
                           exclude_algos = NULL,    # algos to exclude (none)
                           seed = n_seed)
# Leaderboard: all models ranked by cross-validated AUC
model_automl@leaderboard
## model_id auc
## 1 StackedEnsemble_BestOfFamily_0_AutoML_20180624_100747 0.8224382
## 2 StackedEnsemble_AllModels_0_AutoML_20180624_100747 0.8191946
## 3 GLM_grid_0_AutoML_20180624_100747_model_0 0.8178992
## 4 GBM_grid_0_AutoML_20180624_100747_model_5 0.8090076
## 5 DeepLearning_grid_0_AutoML_20180624_100747_model_0 0.8044064
## 6 GBM_grid_0_AutoML_20180624_100747_model_3 0.8013493
## logloss mean_per_class_error rmse mse
## 1 0.4947523 0.2464558 0.4042743 0.1634377
## 2 0.4976774 0.2481450 0.4059306 0.1647797
## 3 0.4933979 0.2506322 0.4035811 0.1628777
## 4 0.5107149 0.2639280 0.4122423 0.1699437
## 5 0.6249904 0.2467667 0.4284441 0.1835644
## 6 0.5216803 0.2686743 0.4179093 0.1746482
##
## [14 rows x 6 columns]
# H2O: Model Leader
# Best model found by AutoML, i.e. the top of the leaderboard
# (either an individual model or a stacked ensemble)
model_automl@leader
## Model Details:
## ==============
##
## H2OBinomialModel: stackedensemble
## Model ID: StackedEnsemble_BestOfFamily_0_AutoML_20180624_100747
## NULL
##
##
## H2OBinomialMetrics: stackedensemble
## ** Reported on training data. **
##
## MSE: 0.1192101
## RMSE: 0.3452682
## LogLoss: 0.3866771
## Mean Per-Class Error: 0.1640793
## AUC: 0.9143488
## Gini: 0.8286976
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## neg pos Error Rate
## neg 237 59 0.199324 =59/296
## pos 21 142 0.128834 =21/163
## Totals 258 201 0.174292 =80/459
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.323387 0.780220 185
## 2 max f2 0.266953 0.860702 213
## 3 max f0point5 0.526851 0.804769 114
## 4 max accuracy 0.526851 0.838780 114
## 5 max precision 0.906104 1.000000 0
## 6 max recall 0.136909 1.000000 301
## 7 max specificity 0.906104 1.000000 0
## 8 max absolute_mcc 0.323387 0.648035 185
## 9 max min_per_class_accuracy 0.354199 0.820946 172
## 10 max mean_per_class_accuracy 0.277362 0.837123 203
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
## H2OBinomialMetrics: stackedensemble
## ** Reported on validation data. **
##
## MSE: 0.1602515
## RMSE: 0.4003143
## LogLoss: 0.4913908
## Mean Per-Class Error: 0.2139762
## AUC: 0.8262911
## Gini: 0.6525822
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## neg pos Error Rate
## neg 57 14 0.197183 =14/71
## pos 9 30 0.230769 =9/39
## Totals 66 44 0.209091 =23/110
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.352618 0.722892 43
## 2 max f2 0.112625 0.805085 79
## 3 max f0point5 0.470776 0.718563 31
## 4 max accuracy 0.470776 0.790909 31
## 5 max precision 0.821827 0.875000 7
## 6 max recall 0.090741 1.000000 97
## 7 max specificity 0.895897 0.985915 0
## 8 max absolute_mcc 0.352618 0.558593 43
## 9 max min_per_class_accuracy 0.352618 0.769231 43
## 10 max mean_per_class_accuracy 0.352618 0.786024 43
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
## H2OBinomialMetrics: stackedensemble
## ** Reported on cross-validation data. **
## ** 5-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 0.1634377
## RMSE: 0.4042743
## LogLoss: 0.4947523
## Mean Per-Class Error: 0.2464558
## AUC: 0.8224382
## Gini: 0.6448765
##
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
## neg pos Error Rate
## neg 230 66 0.222973 =66/296
## pos 44 119 0.269939 =44/163
## Totals 274 185 0.239651 =110/459
##
## Maximum Metrics: Maximum metrics at their respective thresholds
## metric threshold value idx
## 1 max f1 0.359805 0.683908 174
## 2 max f2 0.135984 0.799595 312
## 3 max f0point5 0.566088 0.678514 108
## 4 max accuracy 0.566088 0.762527 108
## 5 max precision 0.896981 1.000000 0
## 6 max recall 0.093384 1.000000 372
## 7 max specificity 0.896981 1.000000 0
## 8 max absolute_mcc 0.369328 0.494940 169
## 9 max min_per_class_accuracy 0.325507 0.739865 187
## 10 max mean_per_class_accuracy 0.359805 0.753544 174
##
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`
# Compare test-set AUC: default GBM vs. the AutoML leader
h2o.auc(h2o.performance(model_gbm, newdata = h_test))
## [1] 0.8238779
# Best model from AutoML
h2o.auc(h2o.performance(model_automl@leader, newdata = h_test)) # higher AUC = better
## [1] 0.8480292
# Predictions on test data: predicted label plus class probabilities
yhat_test <- h2o.predict(model_automl@leader, h_test)
head(yhat_test)
predict | neg | pos |
---|---|---|
neg | 0.9166085 | 0.0833915 |
pos | 0.3595645 | 0.6404355 |
neg | 0.9013221 | 0.0986779 |
neg | 0.6867013 | 0.3132987 |
pos | 0.4748795 | 0.5251205 |
neg | 0.7095436 | 0.2904564 |
Use `h2o.saveModel()` to save a model to disk and `h2o.loadModel()` to re-load it.
Use `h2o.download_mojo()` and `h2o.download_pojo()` to export the model for production deployment.
# Save model to disk
# (force = TRUE overwrites an existing model file at the same path)
h2o.saveModel(object = model_automl@leader,
              path = "./models/",
              force = TRUE)
Use `lime::lime()` to create an explainer:
# Create a LIME explainer from the training data and the AutoML leader.
# lime needs a plain R data frame, so pull the features out of the H2OFrame.
explainer <- lime::lime(x = as.data.frame(h_train[, features]),
                        model = model_automl@leader)
Then use `lime::explain()` to turn the explainer into explanations:
# Extract one sample (change `2` to any row you want)
d_samp <- as.data.frame(h_test[2, features])

# Assign a specific row name (for better visualization)
row.names(d_samp) <- "Sample 2"

# Create explanations
explanations <- lime::explain(x = d_samp,
                              explainer = explainer,
                              n_permutations = 5000,
                              feature_select = "auto",
                              n_labels = 1,   # binary classification: explain top label
                              n_features = 8) # all 8 features (13 exceeded the
                                              # number available; lime caps at 8)
lime::plot_features(explanations, ncol = 1)
# Sort explanations by feature weight (most positive contribution first)
explanations <-
  explanations[order(explanations$feature_weight, decreasing = TRUE), ]
# Print Table
print(explanations)
## model_type case label label_prob model_r2 model_intercept
## 1 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## 3 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## 6 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## 5 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## 8 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## 4 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## 2 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## 7 classification Sample 2 pos 0.6404355 0.1215617 0.3788985
## model_prediction feature feature_value feature_weight
## 1 0.371247 pregnant 10.000 0.138958368
## 3 0.371247 pressure 0.000 0.055554606
## 6 0.371247 mass 35.300 0.038400074
## 5 0.371247 insulin 0.000 0.005134898
## 8 0.371247 age 29.000 -0.013361248
## 4 0.371247 triceps 0.000 -0.058823453
## 2 0.371247 glucose 115.000 -0.083493004
## 7 0.371247 pedigree 0.134 -0.090021739
## feature_desc
## 1 6 < pregnant
## 3 pressure <= 62
## 6 32.2 < mass <= 36.8
## 5 insulin <= 40
## 8 24 < age <= 29
## 4 triceps <= 23
## 2 99 < glucose <= 116
## 7 pedigree <= 0.245
## data
## 1 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## 3 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## 6 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## 5 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## 8 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## 4 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## 2 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## 7 10.000, 115.000, 0.000, 0.000, 0.000, 35.300, 0.134, 29.000
## prediction
## 1 0.3595645, 0.6404355
## 3 0.3595645, 0.6404355
## 6 0.3595645, 0.6404355
## 5 0.3595645, 0.6404355
## 8 0.3595645, 0.6404355
## 4 0.3595645, 0.6404355
## 2 0.3595645, 0.6404355
## 7 0.3595645, 0.6404355
Replace `PimaIndiansDiabetes` with your own dataset to apply this workflow to your own data. Good luck!
# Record R and package versions for reproducibility
sessionInfo()
## R version 3.5.0 (2018-04-23)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: OS X El Capitan 10.11.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] mlbench_2.1-1 lime_0.4.0 h2o_3.20.0.2
##
## loaded via a namespace (and not attached):
## [1] Rcpp_0.12.17 highr_0.7 later_0.7.3
## [4] compiler_3.5.0 pillar_1.2.3 gower_0.1.2
## [7] plyr_1.8.4 bitops_1.0-6 iterators_1.0.9
## [10] tools_3.5.0 digest_0.6.15 jsonlite_1.5
## [13] evaluate_0.10.1 tibble_1.4.2 gtable_0.2.0
## [16] lattice_0.20-35 rlang_0.2.1 Matrix_1.2-14
## [19] foreach_1.4.4 shiny_1.1.0 magick_1.9
## [22] parallel_3.5.0 yaml_2.1.19 stringr_1.3.1
## [25] knitr_1.20 htmlwidgets_1.2 rprojroot_1.3-2
## [28] grid_3.5.0 glmnet_2.0-16 R6_2.2.2
## [31] rmarkdown_1.10 ggplot2_2.2.1 magrittr_1.5
## [34] shinythemes_1.1.1 promises_1.0.1 backports_1.1.2
## [37] scales_0.5.0 codetools_0.2-15 htmltools_0.3.6
## [40] stringdist_0.9.5.1 assertthat_0.2.0 xtable_1.8-2
## [43] mime_0.5 colorspace_1.3-2 httpuv_1.4.4.1
## [46] labeling_0.3 stringi_1.2.3 RCurl_1.95-4.10
## [49] lazyeval_0.2.1 munsell_0.5.0