sl3 and Writing Custom sl3 Learners
sl3
We begin by illustrating a simple execution of the Super Learner algorithm using the chspred data and a small library of learners. Start by loading the necessary packages:
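library(usethis)
# one-time setup: turns the current directory into a project so that here()
# builds paths relative to it; skip this line if you already work in a project
usethis::create_project(".")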
library(here)
library(tidyverse)
library(sl3)
library(knitr)
library(R6)
# prediction data set
chspred <- read_csv(here("data", "chspred.csv"))
head(chspred)
To illustrate the "default" functionality of the Super Learner algorithm (as implemented in sl3), we use the chspred data to predict high-density lipoprotein cholesterol (hdl) from the available covariates.
chspred_task <- make_sl3_Task(
data = chspred,
outcome = "hdl",
covariates = colnames(chspred)[!(colnames(chspred) %in% "hdl")]
)
chspred_task
For the sake of computational expediency, we will initially consider only a simple library of algorithms: an unadjusted (i.e., intercept-only) model and a fast main-effects GLM. Later, we will look at how these algorithms are constructed for use with sl3. We'll use nonnegative least squares to fit the meta-learning step.
lrn1 <- Lrnr_mean$new()
lrn2 <- Lrnr_glm_fast$new()
sl_lrn <- Lrnr_sl$new(learners = list(lrn1, lrn2),
metalearner = Lrnr_nnls$new())
chspred_sl <- sl_lrn$train(chspred_task)
chspred_sl_pred <- chspred_sl$predict()
head(chspred_sl_pred)
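# training-sample MSE: predictions are evaluated on the same data used to fit
# the Super Learner, so this is an optimistic estimate of prediction error
sl_mse <- mean((chspred$hdl - chspred_sl_pred)^2)
sl_mse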
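For intuition about the meta-learning step, the following base-R sketch mimics what a Super Learner with these two candidates (an intercept-only model and a main-effects GLM) does conceptually. It collects cross-validated predictions from each candidate, then finds non-negative weights that minimize squared error. This is an illustration under assumptions, not sl3's implementation of Lrnr_sl or Lrnr_nnls: the data are simulated, five folds are used, and stats::optim with L-BFGS-B box constraints stands in for a dedicated NNLS solver.

```r
# Conceptual sketch only: simulated data, 5-fold CV, and optim() standing in
# for a dedicated NNLS solver; this is not sl3's internal code.
set.seed(42)
n <- 300
dat <- data.frame(x = rnorm(n))
dat$y <- 1 + 2 * dat$x + rnorm(n)

# 1. Cross-validated predictions from each candidate learner
V <- 5
folds <- sample(rep_len(seq_len(V), n))
Z <- matrix(NA_real_, nrow = n, ncol = 2,
            dimnames = list(NULL, c("mean", "glm")))
for (v in seq_len(V)) {
  trn <- dat[folds != v, ]
  val <- dat[folds == v, ]
  Z[folds == v, "mean"] <- mean(trn$y)  # intercept-only model
  Z[folds == v, "glm"] <- predict(glm(y ~ x, data = trn), newdata = val)
}

# 2. Meta-learning: minimize ||y - Z b||^2 subject to b >= 0
sse <- function(b) sum((dat$y - Z %*% b)^2)
sse_grad <- function(b) as.vector(-2 * crossprod(Z, dat$y - Z %*% b))
meta_fit <- optim(c(0.5, 0.5), sse, sse_grad,
                  method = "L-BFGS-B", lower = c(0, 0))
sl_weights <- setNames(meta_fit$par, colnames(Z))

# 3. Weighted combination of the cross-validated candidate predictions
sl_pred <- drop(Z %*% sl_weights)
combo_mse <- mean((dat$y - sl_pred)^2)
sl_weights
```

Because the GLM candidate captures the simulated linear signal, nearly all of the weight should land on it. Unlike sl_mse, combo_mse is computed from cross-validated candidate predictions. The weights themselves, however, were chosen on those same predictions.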
We can also obtain predictions on a new observation by wrapping it in an sl3 task and passing that task to chspred_sl$predict().
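A minimal, self-contained sketch of that step follows. It assumes sl3 and the packages its learners depend on are installed. To keep it runnable without the chspred file, it simulates a small stand-in dataset; the columns age and bmi, the simulated coefficients, and the new observation's values are all hypothetical. It also swaps in Lrnr_glm for Lrnr_glm_fast so that only base glm is needed, and it builds the prediction task from covariates alone, without an outcome column.

```r
library(sl3)

# Hypothetical stand-in data (not the chspred file)
set.seed(123)
n <- 200
train_df <- data.frame(age = rnorm(n, mean = 70, sd = 5),
                       bmi = rnorm(n, mean = 27, sd = 4))
train_df$hdl <- 50 + 0.2 * train_df$age - 0.5 * train_df$bmi + rnorm(n, sd = 5)

train_task <- make_sl3_Task(data = train_df, covariates = c("age", "bmi"),
                            outcome = "hdl")
sl_fit <- Lrnr_sl$new(learners = list(Lrnr_mean$new(), Lrnr_glm$new()),
                      metalearner = Lrnr_nnls$new())$train(train_task)

# A new observation: covariates only, no outcome column
new_obs <- data.frame(age = 72, bmi = 25)
new_task <- make_sl3_Task(data = new_obs, covariates = c("age", "bmi"))
new_pred <- sl_fit$predict(new_task)
new_pred
```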
sl3 Learners
This guide describes the process of implementing a learner class for a new
machine learning algorithm. By writing a learner class for your favorite machine
learning algorithm, you will be able to use it anywhere you could otherwise use
any other sl3 learner, including Pipelines, Stacks, and the Super Learner. We
have done our best to streamline the process of creating new sl3 learners.
Before diving into defining a new learner, it will likely be helpful to read
some background material. If you haven't already read it, the "Modern Machine
Learning in R" vignette is a good introduction to the sl3 package and its
underlying architecture. The R6 documentation will help you understand how R6
classes are defined. In addition, the help files for sl3_Task and Lrnr_base are
good resources for how those objects can be used. If you're interested in
defining learners that fit sub-learners, reading the documentation of the
delayed package will be helpful.
In the following sections, we introduce and review a template for a new sl3
learner, describing the sections that can be used to define your new learner.
This is followed by a discussion of the important task of documenting and
testing your new learner. Finally, we conclude by explaining how you can add
your learner to sl3 so that others may make use of it.
sl3 provides a template of a learner for use in defining new learners. You can
make a copy of the template to work on by invoking write_learner_template:
#write_learner_template("path/to/write/Learner_template.R")
The template has comments indicating where details specific to the learner you're trying to implement should be filled in. In the next section, we will discuss those details further.
At the top of the template, we define an object Lrnr_template and set
classname = "Lrnr_template". You should modify these to match the name of your
new learner, which should also match the name of the corresponding R file. Note
that the name should be prefixed by Lrnr_ and use snake_case.
public$initialize
This function defines the constructor for your learner, and it stores the
arguments (if any) provided when a user calls make_learner(Lrnr_your_learner, ...).
You can also provide default parameter values, just as the template does with
param_1 = "default_1" and param_2 = "default_2". All parameters used by your
newly defined learners should have defaults whenever possible. This will allow
users to use your learner without having to figure out what reasonable parameter
values might be. Parameter values should be documented; see the section below on
documentation for details.
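The template's constructor records these values by calling sl3's args_to_list() and passing the result to super$initialize(). The idea behind that step can be illustrated in base R with a hypothetical capture_params() function, which is not part of sl3. Inside a function, as.list(environment()) returns every named argument, including defaults the caller did not supply, and list(...) adds any extra arguments.

```r
# Hypothetical illustration of capturing constructor arguments (not sl3 code)
capture_params <- function(param_1 = "default_1", param_2 = "default_2", ...) {
  # named formal arguments, with defaults filled in for anything not supplied
  named_args <- as.list(environment())
  # extra arguments that a learner would forward to the wrapped fitting function
  c(named_args, list(...))
}

params <- capture_params(param_2 = "user_value", ntree = 500)
str(params)
```

Because the defaults are recorded alongside user-supplied values, a learner built this way always has a complete parameter list to pass on.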
public$special_functions
You can, of course, define functions for things only your learner can do. These
should be public functions like the special_function defined in the example.
These should be documented; see the section below on documentation for details.
private$.properties
This field defines properties supported by your learner. These may include the
outcome types that are supported, offsets and weights, amongst many other
possibilities. To see a list of all properties supported/used by at least one
learner, you may invoke sl3_list_properties:
sl3_list_properties()
private$.required_packages
This field defines other R packages required for your learner to work properly. These will be loaded when an object of your new learner class is initialized.
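A rough sketch of such a check follows, using a hypothetical helper check_required_packages() that is not part of sl3. requireNamespace() loads each package's namespace without attaching it and returns FALSE if the package is not installed, so all missing dependencies can be reported in one readable error.

```r
# Hypothetical helper (not sl3's implementation)
check_required_packages <- function(pkgs) {
  # TRUE for each package whose namespace could be loaded
  available <- vapply(pkgs, requireNamespace, logical(1), quietly = TRUE)
  if (!all(available)) {
    stop("Please install the following package(s): ",
         paste(pkgs[!available], collapse = ", "), call. = FALSE)
  }
  invisible(TRUE)
}

check_required_packages(c("stats", "utils"))
```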
If you've used sl3 before, you may have noticed that while users are instructed
to use learner$train, learner$predict, and learner$chain to train, generate
predictions, and generate a chained task for a given learner object,
respectively, the template does not implement these methods. Instead, the
template implements private methods called .train, .predict, and .chain. The
specifics of these methods are explained below; however, it is helpful to first
understand how the two sets of methods are related. At the risk of complicating
things further, it is worth noting that there is actually a third set of methods
(learner$base_train, learner$base_predict, and learner$base_chain) of which you
may not be aware.
So, what happens when a user calls learner$train? That method generates a
delayed object using the delayed_learner_train function, and then computes that
delayed object. In turn, delayed_learner_train defines a delayed computation
that calls base_train, a user-facing function that can be used to train tasks
without using the facilities of the delayed package. base_train validates the
user input, and in turn calls private$.train. When private$.train returns a
fit_object, base_train takes that fit object, generates a learner fit object,
and returns it to the user.
Each call to learner$train involves three separate training methods:
1. learner$train, which trains a learner in a manner that can be parallelized
   using delayed, and which calls
2. learner$base_train, which validates user input, and which calls
3. private$.train, which does the actual work of fitting the learner and
   returning the fit object.
The logic in the user-facing learner$train and learner$base_train is defined in
the Lrnr_base base class and is shared across all learners. As such, these
methods need not be reimplemented in individual learners. By contrast,
private$.train contains the behavior that is specific to each individual learner
and should be reimplemented at the level of each individual learner. Since
learner$base_train does not use delayed, it may be helpful to use it when
debugging the training code in a new learner. The program flow used for
prediction and chaining is analogous.
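This division of labor can be mimicked with plain R functions. The sketch below is a simplified analogy, not sl3's code. train() stands in for learner$train, with the delayed step replaced by a direct call. base_train() performs shared input validation and wraps the result. The hypothetical private_train_mean() plays the role of a learner's private$.train.

```r
# Simplified analogy of the train -> base_train -> .train flow (not sl3 code)

# Learner-specific layer: fit the model and return a minimal fit object
private_train_mean <- function(task) {
  list(mean = mean(task$Y))
}

# Shared layer: validate input, call the learner-specific layer, wrap the result
base_train <- function(task, private_train) {
  if (!is.numeric(task$Y) || length(task$Y) == 0) {
    stop("task must contain a non-empty numeric outcome")
  }
  fit_object <- private_train(task)
  structure(list(fit_object = fit_object, n_train = length(task$Y)),
            class = "learner_fit")
}

# User-facing layer: in sl3 this would build and compute a delayed object
train <- function(task, private_train) {
  base_train(task, private_train)
}

fit <- train(list(Y = c(2, 4, 6)), private_train_mean)
fit$fit_object$mean
```

Only private_train_mean() would change from learner to learner; the validation and wrapping in base_train() are written once, which mirrors why .train is the only training method the template asks you to implement.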
private$.train
This is the main training function, which takes in a task and returns a
fit_object that contains all information needed to generate predictions. The
fit object should not contain more data than is absolutely necessary, as
including excess information creates needless inefficiencies. Many learner
functions (like glm) store one or more copies of their training data; this uses
unnecessary memory and will hurt performance for large sample sizes. Thus, these
copies of the data should be removed from the fit object before it is returned.
You may make use of true_obj_size to estimate the size of your fit_object. For
most learners, fit_object size should not grow linearly with training sample
size. If it does, and this is unexpected, please try to reduce the size of the
fit_object.
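As an illustration of trimming a fit object, the sketch below compares full glm objects with hypothetical lean fit objects that keep only the coefficients and the family, which is all a main-effects GLM needs to predict. The simulated data and the make_lean() and predict_lean() helpers are assumptions made for the example, and utils::object.size stands in for sl3's true_obj_size.

```r
# Illustration only: simulated data; make_lean() and predict_lean() are
# hypothetical helpers, and object.size() stands in for sl3's true_obj_size()
simulate_data <- function(n) {
  d <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
  d$y <- 1 + 2 * d$x1 - d$x2 + rnorm(n)
  d
}

# Keep only what prediction needs: the coefficients and the link function
make_lean <- function(glm_fit) {
  list(coefficients = coef(glm_fit), family = glm_fit$family)
}

predict_lean <- function(fit, newdata) {
  X <- cbind(1, as.matrix(newdata[, c("x1", "x2")]))
  fit$family$linkinv(drop(X %*% fit$coefficients))
}

set.seed(1)
d_small <- simulate_data(1000)
d_large <- simulate_data(10000)
full_small <- glm(y ~ x1 + x2, data = d_small)
full_large <- glm(y ~ x1 + x2, data = d_large)
lean_small <- make_lean(full_small)
lean_large <- make_lean(full_large)

# The full glm objects grow with n (they store the data, residuals, QR, ...);
# the lean fit objects do not
sapply(list(full_small = full_small, full_large = full_large,
            lean_small = lean_small, lean_large = lean_large),
       function(obj) format(object.size(obj), units = "Kb"))
```

The lean objects reproduce the full model's predictions while their size stays constant as the training sample grows.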
Most of the time, the learner you are implementing will be fit using a function
that already exists elsewhere. We've built some tools to facilitate passing
parameter values directly to such functions. The private$.train function in the
template uses a common pattern: it builds an argument list from the parameter
values and the task data, and then uses call_with_args to call my_ml_fun with
that argument list. Learners are not required to use this pattern, but it is
helpful in the common case where the learner simply wraps an underlying
my_ml_fun.
By default, call_with_args passes only those arguments in the argument list
whose names match the formal arguments of the function it is calling. This
allows the learner to silently drop irrelevant parameters from the call to
my_ml_fun.
Some learners either capture important arguments using dot arguments (...) or
pass important arguments through such dot arguments on to a secondary function.
Both of these cases can be handled using the other_valid and keep_all options to
call_with_args. The former allows you to list other valid arguments, and the
latter disables argument filtering altogether.
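The filtering behavior described above can be approximated in a few lines of base R. The function call_with_args_sketch() below is a hypothetical re-implementation written for illustration, not sl3's call_with_args. It keeps only arguments whose names match the target function's formal arguments, plus any names listed in other_valid, or it passes everything when keep_all = TRUE. The toy functions are hypothetical stand-ins for my_ml_fun.

```r
# Hypothetical re-implementation for illustration; not sl3's call_with_args()
call_with_args_sketch <- function(fun, args, other_valid = character(0),
                                  keep_all = FALSE) {
  if (!keep_all) {
    valid <- union(names(formals(fun)), other_valid)
    args <- args[names(args) %in% valid]
  }
  do.call(fun, args)
}

# A stand-in for my_ml_fun with explicit formal arguments
toy_fit <- function(x, y, alpha = 1) list(n = length(y), alpha = alpha)
args <- list(x = 1:3, y = c(0, 1, 0), alpha = 0.5, param_1 = "default_1")
call_with_args_sketch(toy_fit, args)  # param_1 is silently dropped

# A function that receives extra arguments through dots
dots_fun <- function(x, ...) names(list(...))
dot_args <- list(x = 1, ntree = 500, junk = TRUE)
call_with_args_sketch(dots_fun, dot_args)                         # NULL
call_with_args_sketch(dots_fun, dot_args, other_valid = "ntree")  # "ntree"
call_with_args_sketch(dots_fun, dot_args, keep_all = TRUE)        # "ntree" "junk"
```

Without filtering, calling toy_fit with param_1 would fail with an "unused argument" error. Filtering lets a learner keep its own parameters in the same list it forwards.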
private$.predict
This is the main prediction function. It takes in a task and generates
predictions for that task using the fit_object. If those predictions are
1-dimensional, they will be coerced to a vector by base_predict.
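The coercion can be pictured with a small hypothetical helper that is not sl3's code. A prediction object with a single column, whether a matrix or a data frame, is flattened to a plain vector. Multi-column predictions, such as class probabilities, are returned unchanged.

```r
# Hypothetical helper illustrating the coercion; not sl3's base_predict()
flatten_predictions <- function(preds) {
  if (!is.null(dim(preds)) && ncol(preds) == 1) {
    # unlist() handles the data-frame case; as.vector() drops dim and names
    preds <- as.vector(unlist(preds))
  }
  preds
}

one_col <- matrix(c(0.2, 0.5, 0.9), ncol = 1)
two_col <- data.frame(p_a = c(0.3, 0.6), p_b = c(0.7, 0.4))
flatten_predictions(one_col)  # numeric vector of length 3
flatten_predictions(two_col)  # data frame left as-is
```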
private$.chain
This is the main chaining function. It takes in a task and generates a chained
task (based on the input task) using the given fit_object. If this method is not
implemented, your learner will use the default chaining behavior, which is to
return a new task where the covariates are defined as your learner's predictions
for the current task.
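Conceptually, the default chaining behavior swaps the original covariates for the learner's predictions while carrying the outcome along. This is how a stacked set of candidate predictions becomes the input to a meta-learner. The sketch below shows that data transformation with plain data frames. chain_sketch() and the column names are hypothetical, and real sl3 chaining operates on sl3_Task objects rather than data frames.

```r
# Hypothetical illustration with data frames; sl3 chains sl3_Task objects instead
chain_sketch <- function(data, outcome, predictions) {
  # predictions: one column per learner, one row per observation
  stopifnot(nrow(predictions) == nrow(data))
  chained <- predictions                 # predictions become the new covariates
  chained[[outcome]] <- data[[outcome]]  # the outcome is carried over unchanged
  chained
}

d <- data.frame(x1 = 1:4, x2 = c(2, 1, 0, 1), y = c(1.1, 2.0, 2.9, 4.2))
preds <- data.frame(
  mean_pred = rep(mean(d$y), nrow(d)),
  glm_pred  = unname(fitted(lm(y ~ x1 + x2, data = d)))
)
chained <- chain_sketch(d, "y", preds)
names(chained)  # "mean_pred" "glm_pred" "y"
```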
If you want other people to be able to use your learner, you will need to document it and provide unit tests for it. The template includes example documentation, written in the roxygen format. Most importantly, you should describe what your learner does, reference any external code it uses, and document any parameters and public methods defined by it.
It's also important to test your learner. You should write unit tests to verify
that your learner can train and predict on new data and, if applicable, generate
a chained task. It might also be a good idea to use the risk function in sl3 to
verify your learner's performance on a sample dataset. That way, if you change
your learner and performance drops, you will know something may have gone wrong.
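Here is a minimal sketch of such tests using the testthat package, assuming sl3 and testthat are installed. Lrnr_glm stands in for your new learner, and the data are simulated. The second test uses a simple holdout mean-squared-error comparison against Lrnr_mean as its performance check, in place of the risk function mentioned above.

```r
library(testthat)
library(sl3)

# Simulated data; Lrnr_glm stands in for the learner you are testing
set.seed(1)
n <- 200
df <- data.frame(x = rnorm(n))
df$y <- 1 + 2 * df$x + rnorm(n)
train_task <- make_sl3_Task(data = df[1:150, ], covariates = "x", outcome = "y")
holdout_task <- make_sl3_Task(data = df[151:200, ], covariates = "x", outcome = "y")

test_that("learner trains, predicts on new data, and chains", {
  fit <- make_learner(Lrnr_glm)$train(train_task)
  preds <- fit$predict(holdout_task)
  expect_length(preds, 50)
  expect_true(all(is.finite(preds)))
  expect_true(inherits(fit$chain(holdout_task), "sl3_Task"))
})

test_that("learner outperforms an intercept-only benchmark", {
  holdout_mse <- function(learner) {
    preds <- learner$train(train_task)$predict(holdout_task)
    mean((holdout_task$Y - preds)^2)
  }
  expect_lt(holdout_mse(make_learner(Lrnr_glm)),
            holdout_mse(make_learner(Lrnr_mean)))
})
```

In a package, tests like these would typically live in files under tests/testthat/.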
sl3
Once you've implemented your new learner (and made sure that it has quality
documentation and unit tests), please consider adding it to the sl3 project.
This will make it possible for other sl3 users to use and build on your work.
Make sure to add any R packages listed in .required_packages to the Suggests:
field of the DESCRIPTION file of the sl3 package. Once this is done, please
submit a Pull Request to the sl3 package on GitHub to request that your learner
be added. If you've never made a "Pull Request" before, see this helpful guide:
https://yangsu.github.io/pull-request-tutorial/.
##' Template of a \code{sl3} Learner.
##'
##' This is a template for defining a new learner.
##' This can be copied to a new file using \code{\link{write_learner_template}}.
##' The remainder of this documentation is an example of how you might write documentation for your new learner.
##' This learner uses \code{\link[my_package]{my_ml_fun}} from \code{my_package} to fit my favorite machine learning algorithm.
##'
##' @docType class
##' @importFrom R6 R6Class
##' @export
##' @keywords data
##' @return Learner object with methods for training and prediction. See \code{\link{Lrnr_base}} for documentation on learners.
##' @format \code{\link{R6Class}} object.
##' @family Learners
##'
##' @section Parameters:
##' \describe{
##' \item{\code{param_1="default_1"}}{ This parameter does something.
##' }
##' \item{\code{param_2="default_2"}}{ This parameter does something else.
##' }
##' \item{\code{...}}{ Other parameters passed directly to \code{\link[my_package]{my_ml_fun}}. See its documentation for details.
##' }
##' }
##'
##' @section Methods:
##' \describe{
##' \item{\code{special_function(arg_1)}}{
##' My learner is special so it has a special function.
##'
##' \itemize{
##' \item{\code{arg_1}: A very special argument.
##' }
##' }
##' }
##' }
Lrnr_template <- R6Class(classname = "Lrnr_template", inherit = Lrnr_base,
portable = TRUE, class = TRUE,
# Above, you should change Lrnr_template (in both the object name and the classname argument)
# to a name that indicates what your learner does
public = list(
# you can define default parameter values here
# if possible, your learner should define defaults for all required parameters
initialize = function(param_1="default_1", param_2="default_2", ...) {
# this captures all parameters to initialize and saves them as self$params
params <- args_to_list()
super$initialize(params = params, ...)
},
# you can define public functions that allow your learner to do special things here
# for instance glm learner might return prediction standard errors
special_function = function(arg_1){
}
),
private = list(
# list properties your learner supports here.
# Use sl3_list_properties() for a list of options
.properties = c(""),
# list any packages required for your learner here.
.required_packages = c("my_package"),
# .train takes task data and returns a fit object that can be used to generate predictions
.train = function(task) {
# generate an argument list from the parameters that were
# captured when your learner was initialized.
# this allows users to pass arguments directly to your ml function
args <- self$params
# get outcome variable type
# preferring learner$params$outcome_type first, then task$outcome_type
outcome_type <- self$get_outcome_type(task)
# should pass something on to your learner indicating outcome_type
# e.g. family or objective
# add task data to the argument list
# what these arguments are called depends on the learner you are wrapping
args$x <- as.matrix(task$X_intercept)
args$y <- outcome_type$format(task$Y)
# only add arguments on weights and offset
# if those were specified when the task was generated
if (task$has_node("weights")) {
args$weights <- task$weights
}
if (task$has_node("offset")) {
args$offset <- task$offset
}
# call a function that fits your algorithm
# with the argument list you constructed
fit_object <- call_with_args(my_ml_fun, args)
# return the fit object, which will be stored
# in a learner object and returned from the call
# to learner$train
return(fit_object)
},
# .predict takes a task and returns predictions from that task
.predict = function(task = NULL) {
# fields available here include self$training_task,
# self$training_outcome_type, and self$fit_object
predictions <- predict(self$fit_object, task$X)
return(predictions)
}
)
)