from sklearn.preprocessing import LabelEncoder
import torch
import pandas as pd
import numpy as np
np.random.seed(0)
import os
import wget
from pathlib import Path
from matplotlib import pyplot as plt
%matplotlib inline
from pytorch_tabnet.tab_model import TabNetClassifier
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
dataset_name = 'census-income'
out = Path(os.getcwd()+'/data/'+dataset_name+'.csv')
out.parent.mkdir(parents=True, exist_ok=True)
if out.exists():
    print("File already exists.")
else:
    print("Downloading file...")
    wget.download(url, out.as_posix())
File already exists.
train = pd.read_csv(out)
# adult.data ships without a header row, so pandas promotes the first data row
# to column names -- hence odd feature names like '39' and ' 2174', and the
# target column being called ' <=50K'
target = ' <=50K'
if "Set" not in train.columns:
train["Set"] = np.random.choice(["train", "valid", "test"], p =[.8, .1, .1], size=(train.shape[0],))
train_indices = train[train.Set=="train"].index
valid_indices = train[train.Set=="valid"].index
test_indices = train[train.Set=="test"].index
nunique = train.nunique()
types = train.dtypes
categorical_columns = []
categorical_dims = {}
for col in train.columns:
    if types[col] == 'object' or nunique[col] < 200:
        print(col, train[col].nunique())
        l_enc = LabelEncoder()
        train[col] = train[col].fillna("VV_likely")
        train[col] = l_enc.fit_transform(train[col].values)
        categorical_columns.append(col)
        categorical_dims[col] = len(l_enc.classes_)
    else:
        # fill only this column; the original train.fillna(..., inplace=True)
        # filled every column in the frame with this column's mean
        train[col] = train[col].fillna(train.loc[train_indices, col].mean())
39 73
 State-gov 9
 Bachelors 16
 13 16
 Never-married 7
 Adm-clerical 15
 Not-in-family 6
 White 5
 Male 2
 2174 119
 0 92
 40 94
 United-States 42
 <=50K 2
Set 3
# check that pipeline accepts strings
train.loc[train[target]==0, target] = "not_wealthy"  # 0 encodes ' <=50K'
train.loc[train[target]==1, target] = "wealthy"      # 1 encodes ' >50K'
unused_feat = ['Set']
features = [ col for col in train.columns if col not in unused_feat+[target]]
cat_idxs = [ i for i, f in enumerate(features) if f in categorical_columns]
cat_dims = [ categorical_dims[f] for i, f in enumerate(features) if f in categorical_columns]
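Each entry of cat_idxs must line up with the matching cardinality in cat_dims, since TabNet builds one embedding table per categorical feature. A quick sanity check (illustrative, not from the original run):

# Hypothetical check: pair each categorical feature index with its cardinality.
for idx, dim in zip(cat_idxs, cat_dims):
    print(f"feature {features[idx]!r}: column {idx}, {dim} categories")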
clf = TabNetClassifier(
    cat_idxs=cat_idxs,
    cat_dims=cat_dims,
    cat_emb_dim=1,
    optimizer_fn=torch.optim.Adam,
    optimizer_params=dict(lr=2e-2),
    scheduler_fn=torch.optim.lr_scheduler.StepLR,
    scheduler_params={"step_size": 50,  # decay the learning rate every 50 epochs
                      "gamma": 0.9},
    mask_type='entmax'  # alternative: "sparsemax"
)
Device used : cuda
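With the scheduler settings above, StepLR multiplies the learning rate by 0.9 every 50 epochs. A minimal sketch of the resulting schedule (illustrative only):

# Illustrative: lr produced by StepLR(step_size=50, gamma=0.9) at a given epoch
lr0, step_size, gamma = 2e-2, 50, 0.9
for epoch in (0, 49, 50, 100):
    print(epoch, lr0 * gamma ** (epoch // step_size))
# 0.02, 0.02, 0.018, 0.0162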
X_train = train[features].values[train_indices]
y_train = train[target].values[train_indices]
X_valid = train[features].values[valid_indices]
y_valid = train[target].values[valid_indices]
X_test = train[features].values[test_indices]
y_test = train[target].values[test_indices]
max_epochs = 1000 if not os.getenv("CI", False) else 2
clf.fit(
    X_train=X_train, y_train=y_train,
    eval_set=[(X_train, y_train), (X_valid, y_valid)],
    eval_name=['train', 'valid'],
    eval_metric=['auc'],
    max_epochs=max_epochs, patience=20,
    batch_size=1024, virtual_batch_size=128,
    num_workers=0,
    weights=1,  # 1 = balanced sampling based on inverse class frequency
    drop_last=False
)
epoch 0  | loss: 0.668   | train_auc: 0.75705 | valid_auc: 0.7551  | 0:00:01s
epoch 1  | loss: 0.52031 | train_auc: 0.81912 | valid_auc: 0.82696 | 0:00:04s
epoch 2  | loss: 0.47527 | train_auc: 0.84816 | valid_auc: 0.85195 | 0:00:06s
...
epoch 34 | loss: 0.31981 | train_auc: 0.93656 | valid_auc: 0.93014 | 0:01:12s
...
epoch 54 | loss: 0.31133 | train_auc: 0.94049 | valid_auc: 0.92622 | 0:01:55s

Early stopping occurred at epoch 54 with best_epoch = 34 and best_valid_auc = 0.93014
Best weights from best epoch are automatically used!
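At this point the fitted model can be persisted with pytorch_tabnet's save_model/load_model API; a minimal sketch (the path is illustrative):

# Hedged sketch: persist and restore the trained classifier.
clf.save_model("./tabnet_census")        # writes ./tabnet_census.zip
loaded_clf = TabNetClassifier()
loaded_clf.load_model("./tabnet_census.zip")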
import sklearn
import skorch
from skorch.helper import predefined_split
import pytorch_tabnet
from pytorch_tabnet.multiclass_utils import infer_output_dim
from pytorch_tabnet.tab_network import TabNet
from pytorch_tabnet.utils import create_explain_matrix
from torch.nn import CrossEntropyLoss
from scipy.sparse import csc_matrix
class SkorchTabModel(skorch.NeuralNet):
    def __init__(
            self,
            criterion,
            module=TabNet,
            module__input_dim=100,
            module__output_dim=5,
            **kwargs,
    ):
        super().__init__(
            module,
            criterion,
            module__input_dim=module__input_dim,
            module__output_dim=module__output_dim,
            **kwargs,
        )

    def initialize_module(self):
        """Set up the network and the explain matrix."""
        kwargs = self.get_params_for('module')
        self.module_ = TabNet(**kwargs).to(self.device)
        self.reducing_matrix_ = create_explain_matrix(
            self.module_.input_dim,
            self.module_.cat_emb_dim,
            self.module_.cat_idxs,
            self.module_.post_embed_dim,
        )

    def compute_feature_importances(self, X):
        """Compute global feature importances by summing the per-sample masks."""
        feature_importances_ = np.zeros(self.module_.post_embed_dim)
        for M_explain, masks in self.forward_masks_iter(X):
            feature_importances_ += M_explain.sum(dim=0).cpu().detach().numpy()
        # map importances from the post-embedding space back to the raw columns
        feature_importances_ = csc_matrix.dot(
            feature_importances_, self.reducing_matrix_,
        )
        return feature_importances_ / np.sum(feature_importances_)

    def on_train_end(self, net, X, **kwargs):
        self.feature_importances_ = self.compute_feature_importances(X)
        super().on_train_end(net, X=X, **kwargs)

    def forward_masks_iter(self, X, training=False, device='cpu'):
        # based on the forward_iter recipe in skorch.NeuralNet
        dataset = self.get_dataset(X)
        iterator = self.get_iterator(dataset, training=training)
        for data in iterator:
            Xi = skorch.dataset.unpack_data(data)[0]
            Xi = skorch.utils.to_device(Xi, self.device)
            with torch.set_grad_enabled(False):
                yp = self.module_.forward_masks(Xi)
            yield skorch.utils.to_device(yp, device=device)

    def explain(self, X):
        """Return local explanations.

        Parameters
        ----------
        X : array-like or torch.Tensor
            Input data.

        Returns
        -------
        M_explain : matrix
            Importance per sample, per column.
        masks : dict of matrices
            Attention masks used by the network, one entry per decision step.
        """
        res_explain = []
        for i, (M_explain, masks) in enumerate(self.forward_masks_iter(X)):
            for key, value in masks.items():
                masks[key] = csc_matrix.dot(
                    value.cpu().detach().numpy(), self.reducing_matrix_
                )
            res_explain.append(
                csc_matrix.dot(M_explain.cpu().detach().numpy(), self.reducing_matrix_)
            )
            if i == 0:
                res_masks = masks
            else:
                for key, value in masks.items():
                    res_masks[key] = np.vstack([res_masks[key], value])
        res_explain = np.vstack(res_explain)
        return res_explain, res_masks

    def predict(self, X):
        y_proba = self.predict_proba(X)
        return y_proba.argmax(-1)
class CrossEntropySparsityLoss(CrossEntropyLoss):
    """Cross-entropy plus TabNet's sparsity regularization term."""

    def __init__(self, lambda_sparse=1e-3):
        super().__init__()
        self.lambda_sparse = lambda_sparse

    def forward(self, y_pred, y_true):
        # the TabNet module returns (logits, M_loss); unpack before scoring
        output, M_loss = y_pred
        loss = super().forward(output, y_true)
        # add the sparsity regularization (same sign convention as pytorch_tabnet)
        loss -= self.lambda_sparse * M_loss
        return loss
label_encoder = LabelEncoder()
label_encoder.fit(y_train)
LabelEncoder()
y_train_enc = label_encoder.transform(y_train)
y_valid_enc = label_encoder.transform(y_valid)
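LabelEncoder assigns integer codes in sorted order of the class names, so 'not_wealthy' maps to 0 and 'wealthy' to 1. A quick check (illustrative, not part of the original run):

# classes_ is sorted alphabetically, so the mapping is deterministic
print(dict(zip(label_encoder.classes_,
               label_encoder.transform(label_encoder.classes_))))
# {'not_wealthy': 0, 'wealthy': 1}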
torch.manual_seed(0)
skorch_clf = SkorchTabModel(
    criterion=CrossEntropySparsityLoss,
    module__input_dim=X_train.shape[-1],
    module__output_dim=infer_output_dim(y_train)[0],
    module__cat_idxs=cat_idxs,
    module__cat_dims=cat_dims,
    module__cat_emb_dim=1,
    module__mask_type='entmax',  # alternative: "sparsemax"
    module__virtual_batch_size=128,
    optimizer=torch.optim.Adam,
    optimizer__lr=2e-2,
    batch_size=1024,
    iterator_train__num_workers=0,
    iterator_train__drop_last=False,
    iterator_valid__num_workers=0,
    iterator_valid__drop_last=False,
    callbacks=[
        skorch.callbacks.LRScheduler(
            policy=torch.optim.lr_scheduler.StepLR,
            step_size=50,
            gamma=0.9,
        ),
        skorch.callbacks.EarlyStopping(patience=20),
        skorch.callbacks.GradientNormClipping(gradient_clip_value=1.),
        skorch.callbacks.EpochScoring('roc_auc'),
    ],
    train_split=predefined_split(skorch.dataset.Dataset(X_valid, y_valid_enc)),
    max_epochs=max_epochs,
    device='cuda',
)
%pdb on
skorch_clf.fit(
    X_train,
    y_train_enc,
    # weights=1 -- TabNetClassifier's balanced sampling has no direct fit() argument in skorch
)
Automatic pdb calling has been turned ON

  epoch    roc_auc    train_loss    valid_loss      lr     dur
-------  ---------  ------------  ------------  ------  ------
      1     0.7858        0.4923        0.4510  0.0200  1.7566
      2     0.8283        0.4203        0.4172  0.0200  1.6597
      3     0.8489        0.3933        0.4054  0.0200  1.6657
    ...
     30     0.9286        0.2666        0.2800  0.0200  1.5541
    ...
     49     0.9257        0.2572        0.2902  0.0200  1.6267
Stopping since valid_loss has not improved in the last 20 epochs.
<class '__main__.SkorchTabModel'>[initialized]( module_=TabNet( (embedder): EmbeddingGenerator( (embeddings): ModuleList( (0): Embedding(73, 1) (1): Embedding(9, 1) (2): Embedding(16, 1) (3): Embedding(16, 1) (4): Embedding(7, 1) (5): Embedding(15, 1) (6): Embedding(6, 1) (7): Embedding(5, 1) (8): Embedding(2, 1) (9): Embedding(119, 1) (10): Embedding(92, 1) (11): Embedding(94, 1) (12): Embedding(42, 1) ) ) (tabnet): TabNetNoEmbeddings( (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True) (encoder): TabNetEncoder( (initial_bn): BatchNorm1d(14, eps=1e-05, momentum=0.01, affine=True, track_running_stats=True) (initial_splitter): FeatTransformer( (shared): GLU_Block( (shared_layers): ModuleList( (0): Linear(in_features=14, out_features=32, bias=False) (1): Linear(in_features=16, out_features=32, bias=False) ) (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=14, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) (specifics): GLU_Block( (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) ) (feat_transformers): ModuleList( (0): FeatTransformer( (shared): GLU_Block( (shared_layers): ModuleList( (0): Linear(in_features=14, out_features=32, bias=False) (1): Linear(in_features=16, out_features=32, bias=False) ) (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=14, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) (specifics): GLU_Block( (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) ) (1): FeatTransformer( (shared): GLU_Block( (shared_layers): ModuleList( (0): Linear(in_features=14, out_features=32, bias=False) (1): Linear(in_features=16, out_features=32, bias=False) ) (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=14, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) (specifics): GLU_Block( (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): 
BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) ) (2): FeatTransformer( (shared): GLU_Block( (shared_layers): ModuleList( (0): Linear(in_features=14, out_features=32, bias=False) (1): Linear(in_features=16, out_features=32, bias=False) ) (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=14, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) (specifics): GLU_Block( (glu_layers): ModuleList( (0): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) (1): GLU_Layer( (fc): Linear(in_features=16, out_features=32, bias=False) (bn): GBN( (bn): BatchNorm1d(32, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) ) ) ) ) ) (att_transformers): ModuleList( (0): AttentiveTransformer( (fc): Linear(in_features=8, out_features=14, bias=False) (bn): GBN( (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) (selector): Entmax15() ) (1): AttentiveTransformer( (fc): Linear(in_features=8, out_features=14, bias=False) (bn): GBN( (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) (selector): Entmax15() ) (2): AttentiveTransformer( (fc): Linear(in_features=8, out_features=14, bias=False) (bn): GBN( (bn): BatchNorm1d(14, eps=1e-05, momentum=0.02, affine=True, track_running_stats=True) ) (selector): Entmax15() ) ) ) (final_mapping): Linear(in_features=8, out_features=2, bias=False) ) ), )
sklearn.metrics.roc_auc_score(y_valid_enc, clf.predict_proba(X_valid)[:, 1])
0.9301440270026657
sklearn.metrics.roc_auc_score(y_valid_enc, skorch_clf.predict_proba(X_valid)[:, 1])
0.9251548093171129
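The same comparison can be run on the untouched test split. A minimal sketch reusing the calls above (these numbers were not produced in the original run):

# Hedged sketch: test-set AUC for both models, using the same label encoding.
y_test_enc = label_encoder.transform(y_test)
print(sklearn.metrics.roc_auc_score(y_test_enc, clf.predict_proba(X_test)[:, 1]))
print(sklearn.metrics.roc_auc_score(y_test_enc, skorch_clf.predict_proba(X_test)[:, 1]))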
sorted(zip(features, clf.feature_importances_), key=lambda x: -x[1])
[(' 2174', 0.1565364299380508),
 (' 13', 0.1505786961213775),
 (' Never-married', 0.13932284831968106),
 ('39', 0.10636645663845869),
 (' 40', 0.10276376959991625),
 (' Male', 0.09467168689847826),
 (' Adm-clerical', 0.08359464814310368),
 (' Not-in-family', 0.06552531442773961),
 (' 77516', 0.044812761723332685),
 (' Bachelors', 0.02024143428103679),
 (' State-gov', 0.016026824831054484),
 (' 0', 0.010171940316317284),
 (' United-States', 0.0052828388496390915),
 (' White', 0.0041043499118138295)]
sorted(zip(features, skorch_clf.feature_importances_), key=lambda x: -x[1])
[(' Never-married', 0.17752436280807216),
 (' Not-in-family', 0.12024081127377748),
 (' 0', 0.10554345933856778),
 (' 2174', 0.10392617550840204),
 (' Male', 0.10214954540875482),
 (' 13', 0.08654041500471207),
 (' Adm-clerical', 0.06583691769885185),
 (' United-States', 0.06418859379572349),
 (' 40', 0.06033358960226785),
 (' Bachelors', 0.04185132059607215),
 ('39', 0.03023046687721213),
 (' State-gov', 0.023454051162537484),
 (' 77516', 0.010296617146939078),
 (' White', 0.007883673778109663)]
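The two rankings broadly agree on the dominant features but differ in the tail. A sketch for eyeballing them side by side (assumes only the objects defined above):

# Hedged sketch: compare the two global importance vectors in one chart.
imp = pd.DataFrame({
    'tabnet': clf.feature_importances_,
    'skorch': skorch_clf.feature_importances_,
}, index=features).sort_values('tabnet')
imp.plot.barh(figsize=(8, 6))
plt.tight_layout()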
explain_matrix, masks = clf.explain(X_valid)
fig, axs = plt.subplots(1, 3, figsize=(20, 20))
for i in range(3):
    axs[i].imshow(masks[i][:50])
    axs[i].set_title(f"mask {i}")
    axs[i].set_xticks(list(range(len(features))))
    axs[i].set_xticklabels(features, rotation=90)
explain_matrix, masks = skorch_clf.explain(X_valid)
fig, axs = plt.subplots(1, 3, figsize=(20, 20))
for i in range(3):
    axs[i].imshow(masks[i][:50])
    axs[i].set_title(f"mask {i}")
    axs[i].set_xticks(list(range(len(features))))
    axs[i].set_xticklabels(features, rotation=90)