# Libraries
# Load required libraries
library(h2o)     # H2O machine learning platform
library(lime)    # machine learning interpretation (LIME)
library(mlbench) # example datasets (BostonHousing)

# Your lucky seed here ... used for the train/test split and model training
n_seed <- 12345
# Load the Boston Housing dataset shipped with mlbench
data("BostonHousing")

# Quick sanity check: 506 rows x 14 columns
dim(BostonHousing)
head(BostonHousing)
crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat | medv |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0.00632 | 18 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296 | 15.3 | 396.90 | 4.98 | 24.0 |
0.02731 | 0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.90 | 9.14 | 21.6 |
0.02729 | 0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |
0.03237 | 0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |
0.06905 | 0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.90 | 5.33 | 36.2 |
0.02985 | 0 | 2.18 | 0 | 0.458 | 6.430 | 58.7 | 6.0622 | 3 | 222 | 18.7 | 394.12 | 5.21 | 28.7 |
# The regression target is the median house value (medv);
# every other column is used as a predictor.
target <- "medv" # Median House Value
features <- setdiff(colnames(BostonHousing), target)
print(features)
## [1] "crim" "zn" "indus" "chas" "nox" "rm" "age"
## [8] "dis" "rad" "tax" "ptratio" "b" "lstat"
# Start a local H2O cluster (JVM)
# h2o.init() connects to an already-running local cluster if one exists,
# otherwise it starts a new one (the log output below shows a fresh start)
h2o.init()
##
## H2O is not running yet, starting it now...
##
## Note: In case of errors look at the following log files:
## /var/folders/4z/p7yt7_4n4fj1jlyq6g4qhfbw0000gn/T//RtmpA1VouU/h2o_jofaichow_started_from_r.out
## /var/folders/4z/p7yt7_4n4fj1jlyq6g4qhfbw0000gn/T//RtmpA1VouU/h2o_jofaichow_started_from_r.err
##
##
## Starting H2O JVM and connecting: . Connection successful!
##
## R is connected to the H2O cluster:
## H2O cluster uptime: 3 seconds 81 milliseconds
## H2O cluster timezone: Europe/London
## H2O data parsing timezone: UTC
## H2O cluster version: 3.20.0.2
## H2O cluster version age: 8 days
## H2O cluster name: H2O_started_from_R_jofaichow_njl100
## H2O cluster total nodes: 1
## H2O cluster total memory: 3.56 GB
## H2O cluster total cores: 8
## H2O cluster allowed cores: 8
## H2O cluster healthy: TRUE
## H2O Connection ip: localhost
## H2O Connection port: 54321
## H2O Connection proxy: NA
## H2O Internal Security: FALSE
## H2O API Extensions: XGBoost, Algos, AutoML, Core V3, Core V4
## R Version: R version 3.5.0 (2018-04-23)
h2o.no_progress() # disable progress bar for RMarkdown
h2o.removeAll()   # optional: remove anything from a previous session

# Upload the R data frame into the H2O cluster as an H2O frame
h_boston <- as.h2o(BostonHousing)

# Inspect the uploaded frame (the original called head(BostonHousing),
# which checks the R copy rather than the H2O frame just created)
head(h_boston)
crim | zn | indus | chas | nox | rm | age | dis | rad | tax | ptratio | b | lstat | medv |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0.00632 | 18 | 2.31 | 0 | 0.538 | 6.575 | 65.2 | 4.0900 | 1 | 296 | 15.3 | 396.90 | 4.98 | 24.0 |
0.02731 | 0 | 7.07 | 0 | 0.469 | 6.421 | 78.9 | 4.9671 | 2 | 242 | 17.8 | 396.90 | 9.14 | 21.6 |
0.02729 | 0 | 7.07 | 0 | 0.469 | 7.185 | 61.1 | 4.9671 | 2 | 242 | 17.8 | 392.83 | 4.03 | 34.7 |
0.03237 | 0 | 2.18 | 0 | 0.458 | 6.998 | 45.8 | 6.0622 | 3 | 222 | 18.7 | 394.63 | 2.94 | 33.4 |
0.06905 | 0 | 2.18 | 0 | 0.458 | 7.147 | 54.2 | 6.0622 | 3 | 222 | 18.7 | 396.90 | 5.33 | 36.2 |
0.02985 | 0 | 2.18 | 0 | 0.458 | 6.430 | 58.7 | 6.0622 | 3 | 222 | 18.7 | 394.12 | 5.21 | 28.7 |
# Split Train/Test (seeded so the split is reproducible)
h_split <- h2o.splitFrame(h_boston, ratios = 0.75, seed = n_seed)
h_train <- h_split[[1]] # 75% for modelling
h_test  <- h_split[[2]] # 25% for evaluation
# Train a default H2O GBM model (all hyperparameters at H2O defaults)
model_gbm <- h2o.gbm(
  x = features,
  y = target,
  training_frame = h_train,
  model_id = "gbm_default_reg",
  seed = n_seed
)
print(model_gbm)
## Model Details:
## ==============
##
## H2ORegressionModel: gbm
## Model ID: gbm_default_reg
## Model Summary:
## number_of_trees number_of_internal_trees model_size_in_bytes min_depth
## 1 50 50 11651 5
## max_depth mean_depth min_leaves max_leaves mean_leaves
## 1 5 5.00000 6 21 13.62000
##
##
## H2ORegressionMetrics: gbm
## ** Reported on training data. **
##
## MSE: 2.76525
## RMSE: 1.662904
## MAE: 1.187687
## RMSLE: 0.0793635
## Mean Residual Deviance : 2.76525
# Evaluate performance on test
# (metrics on unseen data; compare with the training metrics printed above)
h2o.performance(model_gbm, newdata = h_test)
## H2ORegressionMetrics: gbm
##
## MSE: 14.39651
## RMSE: 3.794273
## MAE: 2.515851
## RMSLE: 0.1519358
## Mean Residual Deviance : 14.39651
# Train multiple H2O models with H2O AutoML.
# Stacked Ensembles will be created from those individual models.
# You tell H2O ...
#   1) how much time you have (max_runtime_secs) and/or
#   2) how many models you want (max_models)
# Note: the H2O deep learning algo on multi-core is stochastic, so
# results can vary slightly between runs even with a fixed seed.
model_automl <- h2o.automl(
  x = features,
  y = target,
  training_frame = h_train,
  nfolds = 5,               # 5-fold cross-validation
  max_runtime_secs = 120,   # max wall-clock time
  max_models = 100,         # max number of models
  stopping_metric = "RMSE", # metric to optimize
  project_name = "automl_reg",
  exclude_algos = NULL,     # optionally exclude algorithms
  seed = n_seed
)

# Leaderboard: all trained models ranked by cross-validated performance
model_automl@leaderboard
## model_id
## 1 StackedEnsemble_BestOfFamily_0_AutoML_20180624_101350
## 2 StackedEnsemble_AllModels_0_AutoML_20180624_101350
## 3 DeepLearning_grid_0_AutoML_20180624_101350_model_1
## 4 GBM_grid_0_AutoML_20180624_101350_model_1
## 5 GBM_grid_0_AutoML_20180624_101350_model_3
## 6 GBM_grid_0_AutoML_20180624_101350_model_2
## mean_residual_deviance rmse mse mae rmsle
## 1 9.189284 3.031383 9.189284 2.034305 0.1444506
## 2 9.351899 3.058088 9.351899 2.021522 0.1415304
## 3 9.642519 3.105241 9.642519 2.169955 0.1667710
## 4 11.299723 3.361506 11.299723 2.176395 0.1501403
## 5 11.535687 3.396423 11.535687 2.181420 0.1510191
## 6 11.737661 3.426027 11.737661 2.208042 0.1524836
##
## [20 rows x 6 columns]
# H2O: Model Leader
# Best Model (either an individual model or a stacked ensemble)
# Printing it shows training, validation and cross-validation metrics
model_automl@leader
## Model Details:
## ==============
##
## H2ORegressionModel: stackedensemble
## Model ID: StackedEnsemble_BestOfFamily_0_AutoML_20180624_101350
## NULL
##
##
## H2ORegressionMetrics: stackedensemble
## ** Reported on training data. **
##
## MSE: 2.143234
## RMSE: 1.463979
## MAE: 1.01632
## RMSLE: 0.07402433
## Mean Residual Deviance : 2.143234
##
##
## H2ORegressionMetrics: stackedensemble
## ** Reported on validation data. **
##
## MSE: 11.98943
## RMSE: 3.462576
## MAE: 2.144992
## RMSLE: 0.153796
## Mean Residual Deviance : 11.98943
##
##
## H2ORegressionMetrics: stackedensemble
## ** Reported on cross-validation data. **
## ** 5-fold cross-validation on training data (Metrics computed for combined holdout predictions) **
##
## MSE: 9.189284
## RMSE: 3.031383
## MAE: 2.034305
## RMSLE: 0.1444506
## Mean Residual Deviance : 9.189284
# Default GBM Model
# Test-set metrics for the baseline GBM, for comparison with AutoML below
h2o.performance(model_gbm, newdata = h_test)
## H2ORegressionMetrics: gbm
##
## MSE: 14.39651
## RMSE: 3.794273
## MAE: 2.515851
## RMSLE: 0.1519358
## Mean Residual Deviance : 14.39651
# Best model from AutoML
# Test-set metrics for the AutoML leader
h2o.performance(model_automl@leader, newdata = h_test) # lower RMSE = better
## H2ORegressionMetrics: stackedensemble
##
## MSE: 10.22516
## RMSE: 3.19768
## MAE: 2.113674
## RMSLE: 0.13414
## Mean Residual Deviance : 10.22516
# Predict medv for the held-out test set with the best AutoML model
yhat_test <- h2o.predict(model_automl@leader, h_test)
head(yhat_test)
predict |
---|
35.92196 |
17.99340 |
19.04678 |
19.27645 |
20.35679 |
17.01696 |
# Use h2o.saveModel() to save a model to disk and h2o.loadModel() to
# re-load it later. h2o.download_mojo() and h2o.download_pojo() export
# the model for production deployment.
# Persist the best model as a binary H2O model on disk
h2o.saveModel(
  object = model_automl@leader,
  path = "./models/",
  force = TRUE # overwrite an existing file with the same name
)
# Step 1: Create an explainer from the training data and the model
# Build a LIME explainer from the training features and the AutoML leader
explainer <- lime::lime(
  x = as.data.frame(h_train[, features]),
  model = model_automl@leader
)
# Step 2: Use the explainer to turn individual predictions
# into explanations
# Extract one sample to explain (change `1` to any row you want)
d_samp <- as.data.frame(h_test[1, features])

# Assign a specific row name (for better visualization)
row.names(d_samp) <- "Sample 1"

# Create explanations for that sample
explanations <- lime::explain(
  x = d_samp,
  explainer = explainer,
  n_permutations = 5000,   # number of perturbed samples for the local fit
  feature_select = "auto",
  n_features = 13          # look at the top 13 features (all of them here)
)

# Visualize the per-feature contributions, one facet per case
lime::plot_features(explanations, ncol = 1)
# Sort explanations by feature weight, most positive contribution first
explanations <-
  explanations[order(explanations$feature_weight, decreasing = TRUE), ]

# Print the sorted explanation table
print(explanations)
## model_type case model_r2 model_intercept model_prediction feature
## 6 regression Sample 1 0.5384601 21.99738 36.9296 rm
## 13 regression Sample 1 0.5384601 21.99738 36.9296 lstat
## 10 regression Sample 1 0.5384601 21.99738 36.9296 tax
## 5 regression Sample 1 0.5384601 21.99738 36.9296 nox
## 12 regression Sample 1 0.5384601 21.99738 36.9296 b
## 11 regression Sample 1 0.5384601 21.99738 36.9296 ptratio
## 1 regression Sample 1 0.5384601 21.99738 36.9296 crim
## 7 regression Sample 1 0.5384601 21.99738 36.9296 age
## 3 regression Sample 1 0.5384601 21.99738 36.9296 indus
## 4 regression Sample 1 0.5384601 21.99738 36.9296 chas
## 2 regression Sample 1 0.5384601 21.99738 36.9296 zn
## 9 regression Sample 1 0.5384601 21.99738 36.9296 rad
## 8 regression Sample 1 0.5384601 21.99738 36.9296 dis
## feature_value feature_weight feature_desc
## 6 6.99800 10.46087230 6.59 < rm
## 13 2.94000 7.50459171 lstat <= 7.23
## 10 222.00000 1.76951337 tax <= 277
## 5 0.45800 0.92592275 0.453 < nox <= 0.538
## 12 394.63000 0.65880103 391 < b <= 396
## 11 18.70000 0.49534271 17.4 < ptratio <= 19.1
## 1 0.03237 -0.05366048 crim <= 0.079
## 7 45.80000 -0.13611625 45.6 < age <= 80.5
## 3 2.18000 -0.19423936 indus <= 5.13
## 4 1.00000 -0.34072116 chas = 0
## 2 0.00000 -0.45866336 zn <= 16.2
## 9 3.00000 -2.11054845 rad <= 4
## 8 6.06220 -3.58886893 5.12 < dis
## data
## 6 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 13 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 10 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 5 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 12 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 11 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 1 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 7 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 3 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 4 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 2 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 9 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## 8 0.03237, 0.00000, 2.18000, 1.00000, 0.45800, 6.99800, 45.80000, 6.06220, 3.00000, 222.00000, 18.70000, 394.63000, 2.94000
## prediction
## 6 35.92196
## 13 35.92196
## 10 35.92196
## 5 35.92196
## 12 35.92196
## 11 35.92196
## 1 35.92196
## 7 35.92196
## 3 35.92196
## 4 35.92196
## 2 35.92196
## 9 35.92196
## 8 35.92196
# To apply this workflow to your own problem, replace `BostonHousing`
# with your own data. Good luck!
# Record the R and package versions used (for reproducibility)
sessionInfo()
## R version 3.5.0 (2018-04-23)
## Platform: x86_64-apple-darwin15.6.0 (64-bit)
## Running under: OS X El Capitan 10.11.6
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRblas.0.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/3.5/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] mlbench_2.1-1 lime_0.4.0 h2o_3.20.0.2
##
## loaded via a namespace (and not attached):
## [1] Rcpp_0.12.17 highr_0.7 later_0.7.3
## [4] compiler_3.5.0 pillar_1.2.3 gower_0.1.2
## [7] plyr_1.8.4 bitops_1.0-6 iterators_1.0.9
## [10] tools_3.5.0 digest_0.6.15 jsonlite_1.5
## [13] evaluate_0.10.1 tibble_1.4.2 gtable_0.2.0
## [16] lattice_0.20-35 rlang_0.2.1 Matrix_1.2-14
## [19] foreach_1.4.4 shiny_1.1.0 magick_1.9
## [22] parallel_3.5.0 yaml_2.1.19 stringr_1.3.1
## [25] knitr_1.20 htmlwidgets_1.2 rprojroot_1.3-2
## [28] grid_3.5.0 glmnet_2.0-16 R6_2.2.2
## [31] rmarkdown_1.10 ggplot2_2.2.1 magrittr_1.5
## [34] shinythemes_1.1.1 promises_1.0.1 backports_1.1.2
## [37] scales_0.5.0 codetools_0.2-15 htmltools_0.3.6
## [40] stringdist_0.9.5.1 assertthat_0.2.0 xtable_1.8-2
## [43] mime_0.5 colorspace_1.3-2 httpuv_1.4.4.1
## [46] labeling_0.3 stringi_1.2.3 RCurl_1.95-4.10
## [49] lazyeval_0.2.1 munsell_0.5.0