#!/usr/bin/env python
# coding: utf-8

# # Basic usage
#
# *`skorch`* is designed to maximize interoperability between `sklearn` and `pytorch`. The aim is to keep 99% of the flexibility of `pytorch` while being able to leverage most features of `sklearn`. Below, we show the basic usage of `skorch` and how it can be combined with `sklearn`.
#
# Run in Google Colab
#
# View source on GitHub
# This notebook shows you how to use the basic functionality of `skorch`.
#
# ### Table of contents
#
# * [Definition of the pytorch module](#Definition-of-the-pytorch-module)
# * [Training a classifier](#Training-a-classifier-and-making-predictions)
#   * [Dataset](#A-toy-binary-classification-task)
#   * [pytorch module](#Definition-of-the-pytorch-classification-module)
#   * [Model training](#Defining-and-training-the-neural-net-classifier)
#   * [Inference](#Making-predictions,-classification)
# * [Training a regressor](#Training-a-regressor)
#   * [Dataset](#A-toy-regression-task)
#   * [pytorch module](#Definition-of-the-pytorch-regression-module)
#   * [Model training](#Defining-and-training-the-neural-net-regressor)
#   * [Inference](#Making-predictions,-regression)
# * [Saving and loading a model](#Saving-and-loading-a-model)
#   * [Whole model](#Saving-the-whole-model)
#   * [Only parameters](#Saving-only-the-model-parameters)
# * [Usage with an sklearn Pipeline](#Usage-with-an-sklearn-Pipeline)
# * [Callbacks](#Callbacks)
# * [Grid search](#Usage-with-sklearn-GridSearchCV)
#   * [Special prefixes](#Special-prefixes)
#   * [Performing a grid search](#Performing-a-grid-search)

# In[1]:


# Only installs the dependencies when running inside Google Colab (COLAB_GPU is set there).
get_ipython().system(' [ ! -z "$COLAB_GPU" ] && pip install torch skorch')


# In[2]:


import torch
from torch import nn
from torch.nn import functional as F

# Seed the global RNG so weight initialisation is reproducible; the trailing
# ';' keeps the notebook from echoing the returned generator object.
torch.manual_seed(0);


# ## Training a classifier and making predictions

# ### A toy binary classification task
#
# We load a toy classification task from `sklearn`.

# In[3]:


import numpy as np
from sklearn.datasets import make_classification


# In[4]:


X, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=10,
    random_state=0,
)
# Cast the features to float32 so they match the dtype of the float32 weights in the pytorch layers.
X = X.astype(np.float32)


# In[5]:


X.shape, y.shape, y.mean()


# ### Definition of the `pytorch` classification `module`
#
# We define a vanilla neural network with two hidden layers. The output layer should have 2 output units since there are two classes.
# In addition, it should have a softmax nonlinearity, because later, when calling `predict_proba`, the output from the `forward` call will be used.

# In[6]:


class ClassifierModule(nn.Module):
    """Feed-forward network with two hidden layers for a 2-class problem.

    Expects inputs with 20 features and returns per-class probabilities
    (softmax over the last dimension), which `predict_proba` uses directly.

    Parameters
    ----------
    num_units : int
        Width of the first hidden layer.
    nonlin : callable
        Activation applied after the first hidden layer.
    dropout : float
        Dropout probability applied after the first hidden layer.
    """

    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
            dropout=0.5,
    ):
        super().__init__()
        self.num_units = num_units
        self.nonlin = nonlin
        # Store the layer, not the raw float; the float was previously
        # assigned here and then immediately overwritten by this module.
        self.dropout = nn.Dropout(dropout)

        self.dense0 = nn.Linear(20, num_units)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X


# ### Defining and training the neural net classifier
#
# We use `NeuralNetClassifier` because we're dealing with a classification task. The first argument should be the `pytorch module`. As additional arguments, we pass the number of epochs and the learning rate (`lr`), but those are optional.
#
# *Note*: To use the CUDA backend, pass `device='cuda'` as an additional argument.

# In[7]:


from skorch import NeuralNetClassifier


# In[8]:


net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
    # device='cuda',  # uncomment this to train with CUDA
)


# As in `sklearn`, we call `fit` passing the input data `X` and the targets `y`. By default, `NeuralNetClassifier` makes a `StratifiedKFold` split on the data (80/20) to track the validation loss. This is shown, as well as the train loss and the accuracy on the validation set.

# In[10]:


net.fit(X, y)


# Also, as in `sklearn`, you may call `predict` or `predict_proba` on the fitted model.
# ### Making predictions, classification

# In[11]:


y_pred = net.predict(X[:5])
y_pred


# In[12]:


y_proba = net.predict_proba(X[:5])
y_proba


# ## Training a regressor

# ### A toy regression task

# In[13]:


from sklearn.datasets import make_regression


# In[14]:


X_regr, y_regr = make_regression(1000, 20, n_informative=10, random_state=0)
X_regr = X_regr.astype(np.float32)
# Scale the targets down by 100 (the resulting range is shown in the next cell).
y_regr = y_regr.astype(np.float32) / 100
# Make the target a column vector of shape (n_samples, 1).
y_regr = y_regr.reshape(-1, 1)


# In[15]:


X_regr.shape, y_regr.shape, y_regr.min(), y_regr.max()


# *Note*: Regression currently requires the target to be 2-dimensional, hence the need to reshape. This should be fixed with an upcoming version of pytorch.

# ### Definition of the `pytorch` regression `module`
#
# Again, define a vanilla neural network with two hidden layers. The main difference is that the output layer only has one unit and does not apply a softmax nonlinearity.

# In[16]:


class RegressorModule(nn.Module):
    """Feed-forward network with two hidden layers and one linear output unit.

    Expects inputs with 20 features and returns the raw (un-activated)
    output of the final linear layer.

    Parameters
    ----------
    num_units : int
        Width of the first hidden layer.
    nonlin : callable
        Activation applied after the first hidden layer.
    """

    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
    ):
        super().__init__()
        self.num_units = num_units
        self.nonlin = nonlin

        self.dense0 = nn.Linear(20, num_units)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 1)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = F.relu(self.dense1(X))
        X = self.output(X)
        return X


# ### Defining and training the neural net regressor
#
# Training a regressor is almost the same as training a classifier. Mainly, we use `NeuralNetRegressor` instead of `NeuralNetClassifier` (this is the same terminology as in `sklearn`).

# In[17]:


from skorch import NeuralNetRegressor


# In[18]:


net_regr = NeuralNetRegressor(
    RegressorModule,
    max_epochs=20,
    lr=0.1,
    # device='cuda',  # uncomment this to train with CUDA
)


# In[19]:


net_regr.fit(X_regr, y_regr)


# ### Making predictions, regression
#
# You may call `predict` or `predict_proba` on the fitted model. For regressions, both methods return the same value.
# In[20]:


y_pred = net_regr.predict(X_regr[:5])
y_pred


# ## Saving and loading a model
#
# Save and load either the whole model by using pickle or just the learned model parameters by calling `save_params` and `load_params`.

# ### Saving the whole model

# In[21]:


import os
import pickle
import tempfile


# In[22]:


# Use the platform's temporary directory instead of a hard-coded '/tmp',
# so the example also runs on systems without /tmp (e.g. Windows).
file_name = os.path.join(tempfile.gettempdir(), 'mymodel.pkl')


# In[23]:


with open(file_name, 'wb') as f:
    pickle.dump(net, f)


# In[24]:


# Only unpickle files you trust: loading a pickle can execute arbitrary code.
with open(file_name, 'rb') as f:
    new_net = pickle.load(f)


# ### Saving only the model parameters
#
# This only saves and loads the proper `module` parameters, meaning that hyperparameters such as `lr` and `max_epochs` are not saved. Therefore, to load the model, we have to re-initialize it beforehand.

# In[25]:


net.save_params(f_params=file_name)  # a file-like object also works


# In[26]:


# first initialize the model
new_net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
).initialize()


# In[27]:


# Use the same keyword as save_params so the call is explicit about what is loaded.
new_net.load_params(f_params=file_name)


# ## Usage with an `sklearn Pipeline`
#
# It is possible to put the `NeuralNetClassifier` inside an `sklearn Pipeline`, as you would with any `sklearn` classifier.

# In[28]:


from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


# In[29]:


pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])


# In[30]:


pipe.fit(X, y)


# In[31]:


y_proba = pipe.predict_proba(X[:5])
y_proba


# To save the whole pipeline, including the pytorch module, use `pickle`.

# ## Callbacks
#
# Adding a new callback to the model is straightforward. Below we show how to add a new callback that determines the area under the ROC (AUC) score.

# In[32]:


from skorch.callbacks import EpochScoring


# There is a scoring callback in skorch, `EpochScoring`, which we use for this. We have to specify which score to calculate. We have 3 choices:
#
# * Passing a string: This should be the name of a valid `sklearn` scorer (e.g. `'accuracy'` or `'roc_auc'`). For a list of all existing scores, look [here](https://scikit-learn.org/stable/modules/model_evaluation.html#scoring-parameter).
# * Passing `None`: If you implement your own `.score` method on your neural net, passing `scoring=None` will tell `skorch` to use that.
# * Passing a function or callable: If we want to define our own scoring function, we pass a function with the signature `func(model, X, y) -> score`, which is then used.
#
# Note that this works exactly the same as scoring in `sklearn` does.

# For our case here, since `sklearn` already implements AUC, we just pass the correct string `'roc_auc'`. We should also tell the callback that higher scores are better (to get the correct colors printed below -- by default, lower scores are assumed to be better). Furthermore, we may specify a `name` argument for `EpochScoring`, and whether to use training data (by setting `on_train=True`) or validation data (which is the default).

# In[33]:


# Higher AUC is better, so flip the default "lower is better" assumption.
auc = EpochScoring(lower_is_better=False, scoring='roc_auc')


# Finally, we pass the scoring callback to the `callbacks` parameter as a list and then call `fit`. Notice that we get the printed scores and color highlighting for free.

# In[34]:


net = NeuralNetClassifier(
    ClassifierModule,
    callbacks=[auc],
    lr=0.1,
    max_epochs=20,
)


# In[35]:


net.fit(X, y)


# For information on how to write custom callbacks, have a look at the [Advanced_Usage](https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb) notebook.

# ## Usage with sklearn `GridSearchCV`

# ### Special prefixes
#
# The `NeuralNet` class allows you to directly access parameters of the `pytorch module` by using the `module__` prefix. So, e.g., if you defined the `module` to have a `num_units` parameter, you can set it via the `module__num_units` argument. This is exactly the same logic that allows you to access estimator parameters in `sklearn Pipeline`s and `FeatureUnion`s.
#
# This feature is useful in several ways. For one, it allows you to set those parameters in the model definition. Furthermore, it allows you to set parameters in an `sklearn GridSearchCV` as shown below.
# In addition to the parameters prefixed by `module__`, you may access a couple of other attributes, such as those of the optimizer by using the `optimizer__` prefix (again, see below). All those special prefixes are stored in the `prefixes_` attribute:

# In[36]:


print(', '.join(net.prefixes_))


# ### Performing a grid search
#
# Below we show how to perform a grid search over the learning rate (`lr`), the module's number of hidden units (`module__num_units`), the module's dropout rate (`module__dropout`), and whether the SGD optimizer should use Nesterov momentum or not (`optimizer__nesterov`).

# In[37]:


from sklearn.model_selection import GridSearchCV


# In[38]:


# verbose=0 silences the per-epoch table so the grid-search log stays readable.
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
    verbose=0,
    optimizer__momentum=0.9,
)


# In[39]:


# Keyword-style dict: keys are the same double-underscore parameter names as before.
params = dict(
    lr=[0.05, 0.1],
    module__num_units=[10, 20],
    module__dropout=[0, 0.5],
    optimizer__nesterov=[False, True],
)


# In[40]:


gs = GridSearchCV(
    estimator=net,
    param_grid=params,
    refit=False,
    cv=3,
    scoring='accuracy',
    verbose=2,
)


# In[41]:


gs.fit(X, y)


# In[42]:


print(gs.best_score_, gs.best_params_)


# Of course, we could further nest the `NeuralNetClassifier` within an `sklearn Pipeline`, in which case we just prefix the parameter by the name of the net (e.g. `net__module__num_units`).