import numpy as np
import pandas as pd
import xgboost as xgb
import lightgbm as lgb
import matplotlib.pyplot as plt
import seaborn
from sklearn import preprocessing
import time
# `%matplotlib inline` is an IPython/Jupyter magic and is a SyntaxError in a
# plain Python script; keep it only as a note for notebook use.
# %matplotlib inline
# Otto Group product-classification training data (id + 93 features + target).
train = pd.read_csv('input/otto_train.csv')
print(train.shape)
# Observed output: (61878, 95)
def encode_features(dat):
    """Label-encode every column of *dat* into integer codes.

    Each column's distinct values are mapped to consecutive integers
    0..k-1 in sorted order (the order produced by ``np.unique``), which
    matches what the original per-value masking loop produced.

    Parameters
    ----------
    dat : pd.DataFrame
        Input frame; every column is encoded independently.

    Returns
    -------
    pd.DataFrame
        Same index and column names as *dat*, integer codes as values.
    """
    df = pd.DataFrame(index=dat.index.values)
    for c in dat.columns.values:
        # return_inverse yields each element's code in a single C-level
        # pass, replacing the original O(n * n_unique) boolean-mask loop.
        _, codes = np.unique(dat[c].values, return_inverse=True)
        df[c] = codes.astype(int)
    return df
x = encode_features(train.drop(['id', 'target'], axis=1))
# Targets are strings like 'Class_1'..'Class_9'; strip the prefix and
# shift to 0-based integer labels as required by both boosters.
y = np.array([int(v.split('_')[1]) - 1 for v in train.target])
print(x.shape, y.shape)
# Observed output: (61878, 93) (61878,)
# NOTE: the raw pasted echo "(61878, 93) (61878,)" was a live expression that
# parsed as a tuple call and raised TypeError at runtime; kept as comment.
num_cls = len(np.unique(y))
print(num_cls)
# Observed output: 9
# XGBoost: multi-class soft-probability objective over num_cls classes,
# depth-5 trees, 90% row/column sampling, mlogloss reported each round.
prm_xgb = dict(
    booster='gbtree',
    objective='multi:softprob',
    num_class=num_cls,
    max_depth=5,
    learning_rate=0.1,
    colsample_bytree=0.9,
    subsample=0.9,
    eval_metric='mlogloss',
)

# LightGBM: roughly equivalent settings — 31 leaves approximates a
# depth-5 tree, feature/bagging fractions mirror the sampling above.
prm_lgb = dict(
    boosting_type='gbdt',
    objective='multiclass',
    num_class=num_cls,
    num_leaves=2 ** 5 - 1,
    learning_rate=0.1,
    feature_fraction=0.9,
    bagging_fraction=0.9,
    bagging_freq=1,
    metric='multi_logloss',
)

# Both libraries are run for the same number of boosting rounds.
num_round = 100
# Reproducible 70/30 train/validation split by per-row Bernoulli draw.
np.random.seed(20161218)
flg_train = np.random.choice([False, True], len(y), p=[0.3, 0.7])
flg_valid = np.logical_not(flg_train)

# XGBoost containers.
dt_xgb = xgb.DMatrix(x[flg_train], y[flg_train])
dv_xgb = xgb.DMatrix(x[flg_valid], y[flg_valid])

# LightGBM containers for the plain (numeric-feature) run.
dt_lgb = lgb.Dataset(x[flg_train], y[flg_train])
dv_lgb = lgb.Dataset(x[flg_valid], y[flg_valid], reference=dt_lgb)

# Separate copies for the categorical-feature run; free_raw_data=False keeps
# the raw arrays so categorical_feature can be applied at train time.
dt_lgb_c = lgb.Dataset(x[flg_train], y[flg_train], free_raw_data=False)
# BUG FIX: the validation set for the categorical run must share binning
# with dt_lgb_c (the original referenced dt_lgb, the non-categorical set).
dv_lgb_c = lgb.Dataset(x[flg_valid], y[flg_valid], free_raw_data=False,
                       reference=dt_lgb_c)
# Train XGBoost, reporting train/valid mlogloss every round, and time it.
time_s = time.time()
obj_xgb = xgb.train(
    prm_xgb, dt_xgb, num_round,
    [(dt_xgb, 'train'), (dv_xgb, 'valid')])
time_t = time.time()
print(time_t - time_s)
# Observed: mlogloss falls from ~1.98 to train 0.4897 / valid 0.5704 at
# round 99; wall time ~65.2 s.  (The raw per-round log that was pasted here
# was not valid Python and has been removed.)
# Train LightGBM on the same split with numeric features, time it, and
# persist the model.
time_s = time.time()
obj_lgb = lgb.train(
    prm_lgb, dt_lgb, num_boost_round=num_round,
    valid_sets=dv_lgb)
time_t = time.time()
print(time_t - time_s)
obj_lgb.save_model('output/lgb.txt')
# Observed: valid multi_logloss ~0.5224 after 100 rounds; wall time ~16.0 s
# (about 4x faster than the XGBoost run above).  Raw pasted log removed —
# it was not valid Python.
# Re-train LightGBM treating EVERY column as categorical, to compare
# against the numeric-feature run above.
time_s = time.time()
obj_lgb = lgb.train(
    prm_lgb, dt_lgb_c, num_boost_round=num_round,
    valid_sets=dv_lgb_c,
    categorical_feature=list(range(len(x.columns.values))))
time_t = time.time()
print(time_t - time_s)
obj_lgb.save_model('output/lgb_cat.txt')
# Observed: valid multi_logloss ~0.6713 after 100 rounds (worse than the
# numeric run's ~0.5224); wall time ~15.8 s.  Raw pasted log removed — it
# was not valid Python.