import sys
# Install the uplift-modeling library scikit-uplift plus dill and catboost.
# `{sys.executable}` is expanded by IPython, so pip installs into the same
# Python interpreter that runs this notebook kernel rather than whatever
# `pip` happens to be first on PATH.
!{sys.executable} -m pip install scikit-uplift dill catboost
from IPython.display import clear_output
# Wipe the (long) pip log from this cell's output.
clear_output()
We are going to use the Lenta dataset
from the BigTarget Hackathon hosted in summer 2020 by Lenta and Microsoft.
Lenta is a Russian food retailer.
✏️ The dataset can be loaded from the `sklift.datasets` module using the `fetch_lenta` function.
Read more about the dataset in the API docs.
This is an uplift modeling dataset containing data about Lenta customers' grocery shopping, with marketing campaign communications as the treatment
and store visits as the target.
- `group` — treatment / control flag
- `response_att` — binary target
- `CardHolder` — customer id
- `gender` — customer gender
- `age` — customer age
from sklift.datasets import fetch_lenta
# fetch_lenta() gives back an sklearn Bunch: `data` is the feature
# pd.DataFrame, while `target` and `treatment` are pd.Series.
dataset = fetch_lenta()
0%| | 0.00/145M [00:00<?, ?iB/s]
print(f"Dataset type: {type(dataset)}\n")
# Report the shape of each Bunch component under a human-readable label.
shape_report = (
    ("features", dataset.data),
    ("target", dataset.target),
    ("treatment", dataset.treatment),
)
for label, values in shape_report:
    print(f"Dataset {label} shape: {values.shape}")
Dataset type: <class 'sklearn.utils.Bunch'> Dataset features shape: (687029, 193) Dataset target shape: (687029,) Dataset treatment shape: (687029,)
# Preview the first and last five rows together. DataFrame.append was
# deprecated in pandas 1.4 and removed in pandas 2.0; pd.concat is the
# supported replacement and stacks the rows (keeping their index) the same way.
import pandas as pd
pd.concat([dataset.data.head(), dataset.data.tail()])
age | cheque_count_12m_g20 | cheque_count_12m_g21 | cheque_count_12m_g25 | cheque_count_12m_g32 | cheque_count_12m_g33 | cheque_count_12m_g38 | cheque_count_12m_g39 | cheque_count_12m_g41 | cheque_count_12m_g42 | cheque_count_12m_g45 | cheque_count_12m_g46 | cheque_count_12m_g48 | cheque_count_12m_g52 | cheque_count_12m_g56 | cheque_count_12m_g57 | cheque_count_12m_g58 | cheque_count_12m_g79 | cheque_count_3m_g20 | cheque_count_3m_g21 | cheque_count_3m_g25 | cheque_count_3m_g42 | cheque_count_3m_g45 | cheque_count_3m_g52 | cheque_count_3m_g56 | cheque_count_3m_g57 | cheque_count_3m_g79 | cheque_count_6m_g20 | cheque_count_6m_g21 | cheque_count_6m_g25 | cheque_count_6m_g32 | cheque_count_6m_g33 | cheque_count_6m_g38 | cheque_count_6m_g39 | cheque_count_6m_g40 | cheque_count_6m_g41 | cheque_count_6m_g42 | cheque_count_6m_g45 | cheque_count_6m_g46 | cheque_count_6m_g48 | ... | perdelta_days_between_visits_15_30d | promo_share_15d | response_sms | response_viber | sale_count_12m_g32 | sale_count_12m_g33 | sale_count_12m_g49 | sale_count_12m_g54 | sale_count_12m_g57 | sale_count_3m_g24 | sale_count_3m_g33 | sale_count_3m_g57 | sale_count_6m_g24 | sale_count_6m_g25 | sale_count_6m_g32 | sale_count_6m_g33 | sale_count_6m_g44 | sale_count_6m_g54 | sale_count_6m_g57 | sale_sum_12m_g24 | sale_sum_12m_g25 | sale_sum_12m_g26 | sale_sum_12m_g27 | sale_sum_12m_g32 | sale_sum_12m_g44 | sale_sum_12m_g54 | sale_sum_3m_g24 | sale_sum_3m_g26 | sale_sum_3m_g32 | sale_sum_3m_g33 | sale_sum_6m_g24 | sale_sum_6m_g25 | sale_sum_6m_g26 | sale_sum_6m_g32 | sale_sum_6m_g33 | sale_sum_6m_g44 | sale_sum_6m_g54 | stdev_days_between_visits_15d | stdev_discount_depth_15d | stdev_discount_depth_1m | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 47.0 | 3.0 | 22.0 | 19.0 | 3.0 | 28.0 | 8.0 | 7.0 | 6.0 | 1.0 | 13.0 | 12.0 | 16.0 | 3.0 | 15.0 | 11.0 | 0.0 | 4.0 | 0.0 | 7.0 | 8.0 | 0.0 | 5.0 | 1.0 | 6.0 | 6.0 | 1.0 | 0.0 | 12.0 | 9.0 | 1.0 | 6.0 | 4.0 | 2.0 | 5.0 | 1.0 | 0.0 | 5.0 | 5.0 | 6.0 | ... | 1.3393 | 0.5821 | 0.923077 | 0.071429 | 10.0 | 84.314 | 98.0 | 16.0 | 11.0 | 137.282 | 28.776 | 6.0 | 169.658 | 10.680 | 7.0 | 28.776 | 21.0 | 8.0 | 9.0 | 4469.86 | 658.85 | 1286.32 | 7736.05 | 418.80 | 3233.31 | 811.73 | 2321.61 | 182.82 | 283.84 | 3648.23 | 3141.25 | 356.67 | 237.25 | 283.84 | 3648.23 | 1195.37 | 535.42 | 1.7078 | 0.2798 | 0.3008 |
1 | 57.0 | 1.0 | 0.0 | 2.0 | 1.0 | 1.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 2.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 | 0.0 | 2.0 | 1.0 | 1.0 | 1.0 | 0.0 | 3.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | ... | 0.0000 | 0.0000 | 1.000000 | 0.000000 | 1.0 | 1.000 | 2.0 | 2.0 | 0.0 | 0.000 | 1.000 | 0.0 | 1.744 | 2.000 | 1.0 | 1.000 | 0.0 | 2.0 | 0.0 | 113.39 | 62.69 | 58.71 | 93.35 | 87.01 | 0.00 | 122.98 | 0.00 | 58.71 | 87.01 | 179.83 | 113.39 | 62.69 | 58.71 | 87.01 | 179.83 | 0.00 | 122.98 | 0.0000 | 0.0000 | 0.0000 |
2 | 38.0 | 7.0 | 0.0 | 15.0 | 4.0 | 9.0 | 5.0 | 9.0 | 14.0 | 7.0 | 6.0 | 10.0 | 14.0 | 5.0 | 11.0 | 0.0 | 3.0 | 2.0 | 2.0 | 0.0 | 3.0 | 2.0 | 1.0 | 1.0 | 0.0 | 0.0 | 2.0 | 6.0 | 0.0 | 9.0 | 2.0 | 5.0 | 1.0 | 7.0 | 7.0 | 8.0 | 3.0 | 2.0 | 6.0 | 6.0 | ... | 0.0000 | 0.7256 | 1.000000 | 0.250000 | 5.0 | 21.102 | 50.0 | 109.0 | 0.0 | 0.000 | 7.594 | 0.0 | 25.294 | 11.084 | 3.0 | 11.158 | 31.0 | 59.0 | 0.0 | 1564.91 | 971.09 | 177.93 | 3257.49 | 975.21 | 2555.27 | 6351.29 | 0.00 | 0.00 | 0.00 | 783.87 | 1239.19 | 533.46 | 83.37 | 593.13 | 1217.43 | 1336.83 | 3709.82 | 0.0000 | NaN | 0.0803 |
3 | 65.0 | 6.0 | 3.0 | 25.0 | 2.0 | 10.0 | 14.0 | 11.0 | 8.0 | 1.0 | 0.0 | 2.0 | 6.0 | 7.0 | 2.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 5.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 2.0 | 1.0 | 11.0 | 2.0 | 3.0 | 5.0 | 5.0 | 4.0 | 2.0 | 1.0 | 0.0 | 1.0 | 3.0 | ... | 0.0000 | 0.0000 | 0.909091 | 0.000000 | 2.0 | 12.544 | 49.0 | 39.0 | 0.0 | 0.000 | 2.778 | 0.0 | 2.000 | 34.212 | 2.0 | 3.778 | 2.0 | 13.0 | 0.0 | 358.22 | 3798.18 | 680.93 | 1425.07 | 175.73 | 602.81 | 3544.76 | 0.00 | 119.99 | 73.24 | 346.74 | 139.68 | 1849.91 | 360.40 | 175.73 | 496.73 | 172.58 | 1246.21 | 0.0000 | 0.0000 | 0.0000 |
4 | 61.0 | 0.0 | 1.0 | 2.0 | 0.0 | 2.0 | 1.0 | 0.0 | 3.0 | 2.0 | 1.0 | 1.0 | 5.0 | 5.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 1.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 2.0 | 0.0 | 2.0 | 1.0 | 0.0 | 8.0 | 2.0 | 2.0 | 1.0 | 1.0 | 4.0 | ... | 0.0000 | 0.7865 | 1.000000 | 0.100000 | 0.0 | 1.454 | 25.0 | 25.0 | 0.0 | 0.000 | 0.454 | 0.0 | 3.036 | 12.000 | 0.0 | 1.454 | 8.0 | 23.0 | 0.0 | 226.98 | 168.05 | 960.37 | 1560.21 | 0.00 | 342.45 | 1039.85 | 0.00 | 66.18 | 0.00 | 87.94 | 226.98 | 168.05 | 461.37 | 0.00 | 237.93 | 225.51 | 995.27 | 1.4142 | 0.3495 | 0.3495 |
687024 | 35.0 | 0.0 | 0.0 | 4.0 | 0.0 | 2.0 | 0.0 | 1.0 | 0.0 | 3.0 | 2.0 | 2.0 | 3.0 | 2.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 3.0 | 2.0 | 1.0 | 2.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 3.0 | 0.0 | 2.0 | 0.0 | 0.0 | 5.0 | 0.0 | 2.0 | 2.0 | 2.0 | 2.0 | ... | 1.3333 | 0.4002 | 0.000000 | 0.166667 | 0.0 | 3.000 | 14.0 | 2.0 | 0.0 | 19.856 | 3.000 | 0.0 | 19.856 | 29.000 | 0.0 | 3.000 | 15.0 | 1.0 | 0.0 | 550.09 | 695.32 | 111.87 | 114.21 | 0.00 | 1173.84 | 147.68 | 550.09 | 111.87 | 0.00 | 330.96 | 550.09 | 669.33 | 111.87 | 0.00 | 330.96 | 1173.84 | 119.99 | 2.6458 | 0.3646 | 0.3282 |
687025 | 33.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | NaN | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2.0 | ... | 0.0000 | 0.0000 | 1.000000 | 0.000000 | 0.0 | 0.000 | 1.0 | 1.0 | 0.0 | NaN | NaN | NaN | 0.000 | 0.000 | 0.0 | 0.000 | 0.0 | 1.0 | 0.0 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 28.01 | NaN | NaN | NaN | NaN | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 28.01 | 0.0000 | 0.0000 | 0.0000 |
687026 | 36.0 | 0.0 | 0.0 | 3.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0000 | 0.9847 | 1.000000 | 0.000000 | 0.0 | 0.000 | 5.0 | 3.0 | 0.0 | 0.000 | 0.000 | 0.0 | 0.000 | 0.000 | 0.0 | 0.000 | 15.0 | 0.0 | 0.0 | 0.00 | 155.97 | 23.99 | 41.51 | 0.00 | 615.77 | 87.47 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 449.01 | 0.00 | 0.0000 | NaN | NaN |
687027 | 37.0 | 0.0 | 1.0 | 2.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0000 | 0.8318 | 1.000000 | 0.000000 | 0.0 | 0.000 | 1.0 | 0.0 | 0.0 | 0.000 | 0.000 | 0.0 | 0.000 | 0.476 | 0.0 | 0.000 | 0.0 | 0.0 | 0.0 | 0.00 | 81.90 | 29.82 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 46.72 | 0.00 | 0.00 | 0.00 | 0.00 | 0.00 | 0.0000 | NaN | NaN |
687028 | 40.0 | 0.0 | 1.0 | 0.0 | 0.0 | 2.0 | 0.0 | 0.0 | 2.0 | 2.0 | 2.0 | 2.0 | 3.0 | 1.0 | 1.0 | 2.0 | 1.0 | 4.0 | 0.0 | 1.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 1.0 | 3.0 | 0.0 | 1.0 | 0.0 | 0.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | 2.0 | 2.0 | ... | 0.0000 | 0.0000 | 1.000000 | 0.100000 | 0.0 | 6.452 | 25.0 | 17.0 | 3.0 | 6.660 | 1.344 | 1.0 | 6.660 | 0.000 | 0.0 | 1.344 | 18.0 | 4.0 | 1.0 | 531.25 | 0.00 | 0.00 | 916.44 | 0.00 | 2407.56 | 1304.03 | 290.01 | 0.00 | 0.00 | 228.47 | 290.01 | 0.00 | 0.00 | 0.00 | 228.47 | 752.32 | 596.86 | 0.0000 | 0.0000 | 0.0000 |
10 rows × 193 columns
treatment / control
import pandas as pd
# Distribution of the target within each treatment group;
# normalize='index' makes every row (group) sum to 1.
pd.crosstab(dataset.treatment, dataset.target, normalize='index')
response_att | 0 | 1 |
---|---|---|
group | ||
control | 0.897421 | 0.102579 |
test | 0.889874 | 0.110126 |
# Make the treatment flag binary: 'test' -> 1, 'control' -> 0.
treat_dict = {
    'test': 1,
    'control': 0
}
mapped_treatment = dataset.treatment.map(treat_dict)
# Series.map turns every label missing from treat_dict into NaN. That would
# silently corrupt the treatment flag if an unexpected label shows up, or if
# this cell is re-run on an already-encoded 0/1 column, so fail loudly.
unmapped_labels = dataset.treatment[
    mapped_treatment.isna() & dataset.treatment.notna()
].unique()
if len(unmapped_labels) > 0:
    raise ValueError(
        f"Unexpected treatment labels (expected {sorted(treat_dict)}): "
        f"{list(unmapped_labels)}"
    )
dataset.treatment = mapped_treatment
# CatBoostClassifier gets `gender` as a categorical feature, so its missing
# values are replaced by an explicit "not determined" category first.
gender_column = dataset.data['gender'].fillna(value='Не определен')
dataset.data['gender'] = gender_column
print(gender_column.value_counts(dropna=False))
Ж 433448 М 243910 Не определен 9671 Name: gender, dtype: int64
Intuition:
In a binary classification problem we stratify the train/validation split by the 0/1 target
column. In uplift modeling we have two columns (treatment and target) instead of one, so we stratify by both.
from sklearn.model_selection import train_test_split
# Stratify on the joint (treatment, target) pair rather than on the target
# alone, so both splits keep the same mix of treatment/outcome combinations.
stratify_cols = pd.concat([dataset.treatment, dataset.target], axis=1)

(X_train, X_val,
 trmnt_train, trmnt_val,
 y_train, y_val) = train_test_split(
    dataset.data, dataset.treatment, dataset.target,
    test_size=0.3, stratify=stratify_cols, random_state=42,
)

for split_name, split_frame in (("Train", X_train), ("Validation", X_val)):
    print(f"{split_name} shape: {split_frame.shape}")
Train shape: (480920, 193) Validation shape: (206109, 193)
from catboost import CatBoostClassifier
from sklearn.base import clone
from sklift.models import TwoModels
from sklift.models import ClassTransformation
# CatBoost base learner configuration; `gender` is declared categorical.
first_estimator = CatBoostClassifier(verbose=100,
                                     task_type="GPU",
                                     devices='0:1',
                                     cat_features=['gender'],
                                     random_state=42,
                                     thread_count=1)
second_estimator = clone(first_estimator)

# Each meta-model must own its estimator instance. sklift stores the object it
# is given and calls .fit() on it directly, without copying it. Sharing
# `first_estimator` between ClassTransformation and TwoModels therefore let
# two_model.fit() refit the ClassTransformation learner on the treatment group
# only, and transform_model.predict() then scored with the wrong model.
# ClassTransformation gets its own unfitted copy.
transform_model = ClassTransformation(estimator=clone(first_estimator))
two_model = TwoModels(estimator_trmnt=first_estimator, estimator_ctrl=second_estimator)

# ClassTransformation: a single classifier wrapped by the meta-model.
transform_model = transform_model.fit(
    X=X_train,
    y=y_train,
    treatment=trmnt_train
)

# TwoModels: separate classifiers for the treatment and control groups.
two_model = two_model.fit(
    X=X_train,
    y=y_train,
    treatment=trmnt_train
)
Learning rate set to 0.024003 0: learn: 0.6893849 total: 59ms remaining: 58.9s 100: learn: 0.6100331 total: 5.39s remaining: 48s 200: learn: 0.6019326 total: 12s remaining: 47.8s 300: learn: 0.6000429 total: 18.7s remaining: 43.4s 400: learn: 0.5992161 total: 25.4s remaining: 37.9s 500: learn: 0.5986674 total: 32s remaining: 31.8s 600: learn: 0.5982996 total: 38.5s remaining: 25.5s 700: learn: 0.5980941 total: 44.8s remaining: 19.1s 800: learn: 0.5979237 total: 51.2s remaining: 12.7s 900: learn: 0.5976503 total: 57.5s remaining: 6.32s 999: learn: 0.5975015 total: 1m 3s remaining: 0us Learning rate set to 0.02591 0: learn: 0.6711650 total: 23.1ms remaining: 23.1s 100: learn: 0.2887976 total: 3.03s remaining: 27s 200: learn: 0.2763838 total: 6.53s remaining: 26s 300: learn: 0.2729584 total: 10.2s remaining: 23.7s 400: learn: 0.2713649 total: 13.9s remaining: 20.8s 500: learn: 0.2703728 total: 17.6s remaining: 17.6s 600: learn: 0.2696703 total: 21.3s remaining: 14.1s 700: learn: 0.2691328 total: 24.9s remaining: 10.6s 800: learn: 0.2686616 total: 28.6s remaining: 7.11s 900: learn: 0.2682632 total: 32.3s remaining: 3.55s 999: learn: 0.2678762 total: 36s remaining: 0us Learning rate set to 0.024384 0: learn: 0.6735712 total: 44.9ms remaining: 44.9s 100: learn: 0.3063022 total: 4.82s remaining: 42.9s 200: learn: 0.2925770 total: 10.2s remaining: 40.4s 300: learn: 0.2895685 total: 15.6s remaining: 36.3s 400: learn: 0.2880540 total: 21.3s remaining: 31.9s 500: learn: 0.2872389 total: 26.9s remaining: 26.8s 600: learn: 0.2866951 total: 32.6s remaining: 21.6s 700: learn: 0.2863474 total: 38.1s remaining: 16.3s 800: learn: 0.2860138 total: 43.6s remaining: 10.8s 900: learn: 0.2857359 total: 49.2s remaining: 5.41s 999: learn: 0.2854954 total: 54.8s remaining: 0us
# Validation-split uplift scores for both models.
uplift_transform_model_val = transform_model.predict(X_val)
uplift_two_model = two_model.predict(X_val)
# Training-split scores (ClassTransformation only), used together with the
# validation scores by the ASD stability metric.
uplift_transform_model_train = transform_model.predict(X_train)
uplift@k
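$$\text{uplift@}k = \bar{y}^{\,T}_{k} - \bar{y}^{\,C}_{k},$$
where $\bar{y}^{\,T}_{k}$ is the target mean at the top k% in the treatment group and $\bar{y}^{\,C}_{k}$ is the target mean at the top k% in the control group.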
How to count uplift@k:
Code parameter options:
- `strategy='overall'` — sort treatment and control objects together by predicted uplift
- `strategy='by_group'` — sort treatment and control objects separately by predicted uplift
🚀 uplift@k with a small step of the k parameter
import matplotlib.pyplot as plt
import numpy as np
from sklift.metrics import uplift_at_k

# k grid 0.01, 0.02, ..., 0.99. Dividing integers by 100 gives the closest
# float to each k, whereas np.arange(0.01, 1, 0.01) accumulates rounding error
# (e.g. 0.060000000000000005) that may shift the top-k cut-off by a row.
values_k = list(np.arange(1, 100) / 100)

# uplift@k on the validation set (strategy='overall') for each model.
values_uplift_k_transform = [
    uplift_at_k(y_val, uplift_transform_model_val, trmnt_val, strategy='overall', k=k)
    for k in values_k
]
values_uplift_k_two = [
    uplift_at_k(y_val, uplift_two_model, trmnt_val, strategy='overall', k=k)
    for k in values_k
]
For ClassTransformation model
# uplift@k as a function of k for the ClassTransformation model.
plt.plot(values_k, values_uplift_k_transform)
plt.title('Dependence of uplift@k on k')
plt.xlabel('The value of k')
plt.ylabel('The value of uplift@k')
plt.show()
For TwoModels
# uplift@k as a function of k for the TwoModels approach.
plt.plot(values_k, values_uplift_k_two)
plt.title('Dependence of uplift@k on k')
plt.xlabel('The value of k')
plt.ylabel('The value of uplift@k')
plt.show()
ASD metric
The average squared deviation (ASD) is a model stability metric that shows how much the model overfits the training data. Larger ASD values mean greater overfitting.
- `strategy='overall'` — take the first k observations of all test data ordered by uplift prediction (both groups, control and treatment, together) and compute the conversions in the treatment and control groups on those observations only. Then the difference between these conversions is calculated.
- `strategy='by_group'` — separately compute the conversions in the top k observations of each group (control and treatment), each sorted by uplift prediction. Then the difference between these conversions is calculated.
- `bins=10` — the number of bins (and the corresponding percentiles) the data is divided into.
from sklift.metrics import average_squared_deviation
# ASD of the ClassTransformation model's train vs. validation uplift scores,
# computed once per top-k selection strategy.
asd_by_strategy = {
    strategy: average_squared_deviation(
        y_train, uplift_transform_model_train, trmnt_train,
        y_val, uplift_transform_model_val, trmnt_val,
        strategy=strategy,
    )
    for strategy in ('overall', 'by_group')
}
asd_overall = asd_by_strategy['overall']
asd_by_group = asd_by_strategy['by_group']

print(f"average squared deviation by overall strategy for the ClassTransformation model: {asd_overall:.6f}")
print(f"average squared deviation by group strategy for the ClassTransformation model: {asd_by_group:.6f}")
average squared deviation by overall strategy for the ClassTransformation model: 0.000007 average squared deviation by group strategy for the ClassTransformation model: 0.000011
↗️ Display the uplift scores of 2 different models on one Qini plot
Only Qini curves
from sklift.viz import plot_qini_curve

# Both models' Qini curves on one shared axis, without the random/perfect
# reference curves. NOTE: despite its name, `ax_roc` is the Qini-curve axis.
fig, ax_roc = plt.subplots(1, 1)
plot_qini_curve(y_val, uplift_transform_model_val, trmnt_val, name='Transform model', random=False, perfect=False, ax=ax_roc)
plot_qini_curve(y_val, uplift_two_model, trmnt_val, name='Two models', random=False, perfect=False, ax=ax_roc)
<sklift.viz.base.UpliftCurveDisplay at 0x7fe6cbfe8710>
Qini curves with a random curve and with a perfect curve
# Both models' Qini curves plus the random and perfect reference curves.
# The reference curves are computed from y_val and trmnt_val and do not
# depend on which model's scores are passed. Requesting them on both calls
# drew identical lines twice, with duplicate legend entries, so they are
# drawn only once.
fig, ax_roc = plt.subplots(1, 1)
plot_qini_curve(y_val, uplift_transform_model_val, trmnt_val, name='Transform model', random=True, perfect=True, ax=ax_roc)
plot_qini_curve(y_val, uplift_two_model, trmnt_val, name='Two models', random=False, perfect=False, ax=ax_roc)
<sklift.viz.base.UpliftCurveDisplay at 0x7fe6cb0e9710>