12 Feb 2018, marugari
PART (Peeking Additive Regression Trees) aims to improve boosting by peeking at a held-out subset of the training data while growing trees.
For training a PART booster, we need to split the training data into 3 parts.
Repository (https://github.com/marugari/LightGBM/tree/part)
This is implemented as a LightGBM custom booster. The following is a fork of the Kaggle Zillow Prize Kernel.
import numpy as np
import pandas as pd
import lightgbm as lgb
import gc
# Load the Zillow data and assemble the training matrix.
train = pd.read_csv('input/zillow/train_2016_v2.csv', engine='python')
prop = pd.read_csv('input/zillow/properties_2016.csv', engine='python')

# Downcast float64 property columns to float32 to halve memory usage.
for col_name, col_dtype in zip(prop.columns, prop.dtypes):
    if col_dtype == np.float64:
        prop[col_name] = prop[col_name].astype(np.float32)

df_train = train.merge(prop, how='left', on='parcelid')

# Columns excluded from the feature matrix: the id, the target, and
# free-text/date columns that are not usable as numeric features.
drop_cols = [
    'parcelid',
    'logerror',
    'transactiondate',
    'propertyzoningdesc',
    'propertycountylandusecode',
]
x_train = df_train.drop(drop_cols, axis=1)
y_train = df_train['logerror'].values
print(x_train.shape, y_train.shape)

train_columns = x_train.columns
# Crude handling of object-dtype columns: replace each with a boolean flag
# that is True only where the cell is literally True.
for col_name in x_train.dtypes[x_train.dtypes == object].index.values:
    x_train[col_name] = (x_train[col_name] == True)
del df_train
(90275, 55) (90275,)
# Hold out the tail of the frame (rows 80000+) as the validation set.
split = 80000
xt = x_train[:split].values.astype(np.float32, copy=False)
xv = x_train[split:].values.astype(np.float32, copy=False)
yt, yv = y_train[:split], y_train[split:]

# free_raw_data=False keeps the numpy arrays alive so they can be
# re-sliced for the PART/peek split below.
ds_train = lgb.Dataset(xt, label=yt, free_raw_data=False)
ds_valid = lgb.Dataset(xv, label=yv, free_raw_data=False)
# Shared LightGBM parameters for both boosters.
prm = {
    'learning_rate': 0.002,
    'boosting_type': 'gbdt',
    'objective': 'regression',
    'metric': 'mae',            # evaluate with mean absolute error (l1)
    'sub_feature': 0.5,         # alias of feature_fraction
    'num_leaves': 60,
    'min_data': 500,            # alias of min_data_in_leaf
    'min_hessian': 1,           # alias of min_sum_hessian_in_leaf
}
num_round = 500

# Baseline: a plain GBDT booster trained on the full training split.
clf_gbdt = lgb.train(prm, ds_train, num_round)
# Parameters for the PART booster. NOTE: the original code did
# `prm_part = prm`, which only aliases the dict — every edit below then
# silently mutated the GBDT parameters `prm` as well. Take a real copy.
prm_part = dict(prm)
prm_part['boosting_type'] = 'part'
prm_part['learning_rate'] = 0.002
prm_part['drop_rate'] = 0.0
prm_part['skip_drop'] = 0.0

# Split the training rows ~70/30 into a "part" set (trees are grown on it)
# and a "peek" set (held out and watched as the validation set).
np.random.seed(20180212)
flg_part = np.random.choice([True, False], len(yt), replace=True, p=[0.7, 0.3])
flg_peek = np.logical_not(flg_part)
ds_part = lgb.Dataset(xt[flg_part], label=yt[flg_part], free_raw_data=False)
ds_peek = lgb.Dataset(xt[flg_peek], label=yt[flg_peek], free_raw_data=False)
clf_part = lgb.train(prm_part, ds_part, num_round, valid_sets=ds_peek)
[1] valid_0's l1: 0.0683414 [2] valid_0's l1: 0.0683379 [3] valid_0's l1: 0.0683343 [4] valid_0's l1: 0.068331 [5] valid_0's l1: 0.0683291 [6] valid_0's l1: 0.0683264 [7] valid_0's l1: 0.0683249 [8] valid_0's l1: 0.0683225 [9] valid_0's l1: 0.06832 [10] valid_0's l1: 0.0683163 [11] valid_0's l1: 0.0683139 [12] valid_0's l1: 0.0683106 [13] valid_0's l1: 0.0683076 [14] valid_0's l1: 0.0683049 [15] valid_0's l1: 0.0683014 [16] valid_0's l1: 0.0682984 [17] valid_0's l1: 0.0682964 [18] valid_0's l1: 0.0682937 [19] valid_0's l1: 0.0682904 [20] valid_0's l1: 0.0682873 [21] valid_0's l1: 0.0682854 [22] valid_0's l1: 0.0682819 [23] valid_0's l1: 0.0682799 [24] valid_0's l1: 0.068277 [25] valid_0's l1: 0.0682755 [26] valid_0's l1: 0.0682727 [27] valid_0's l1: 0.0682709 [28] valid_0's l1: 0.0682689 [29] valid_0's l1: 0.068267 [30] valid_0's l1: 0.0682641 [31] valid_0's l1: 0.0682614 [32] valid_0's l1: 0.0682589 [33] valid_0's l1: 0.0682562 [34] valid_0's l1: 0.0682533 [35] valid_0's l1: 0.0682514 [36] valid_0's l1: 0.0682481 [37] valid_0's l1: 0.0682448 [38] valid_0's l1: 0.0682428 [39] valid_0's l1: 0.0682411 [40] valid_0's l1: 0.0682378 [41] valid_0's l1: 0.0682362 [42] valid_0's l1: 0.0682346 [43] valid_0's l1: 0.0682316 [44] valid_0's l1: 0.0682295 [45] valid_0's l1: 0.0682262 [46] valid_0's l1: 0.0682243 [47] valid_0's l1: 0.0682211 [48] valid_0's l1: 0.0682192 [49] valid_0's l1: 0.0682163 [50] valid_0's l1: 0.068214 [51] valid_0's l1: 0.0682118 [52] valid_0's l1: 0.0682092 [53] valid_0's l1: 0.068208 [54] valid_0's l1: 0.0682054 [55] valid_0's l1: 0.068203 [56] valid_0's l1: 0.0682009 [57] valid_0's l1: 0.0681992 [58] valid_0's l1: 0.0681973 [59] valid_0's l1: 0.0681945 [60] valid_0's l1: 0.0681932 [61] valid_0's l1: 0.0681908 [62] valid_0's l1: 0.0681888 [63] valid_0's l1: 0.0681869 [64] valid_0's l1: 0.0681849 [65] valid_0's l1: 0.0681839 [66] valid_0's l1: 0.0681821 [67] valid_0's l1: 0.06818 [68] valid_0's l1: 0.0681773 [69] valid_0's l1: 0.0681753 [70] valid_0's 
l1: 0.0681734 [71] valid_0's l1: 0.0681712 [72] valid_0's l1: 0.0681695 [73] valid_0's l1: 0.068168 [74] valid_0's l1: 0.0681664 [75] valid_0's l1: 0.0681642 [76] valid_0's l1: 0.0681616 [77] valid_0's l1: 0.068159 [78] valid_0's l1: 0.0681568 [79] valid_0's l1: 0.0681557 [80] valid_0's l1: 0.068154 [81] valid_0's l1: 0.0681519 [82] valid_0's l1: 0.0681503 [83] valid_0's l1: 0.0681483 [84] valid_0's l1: 0.0681466 [85] valid_0's l1: 0.068144 [86] valid_0's l1: 0.068142 [87] valid_0's l1: 0.0681402 [88] valid_0's l1: 0.0681379 [89] valid_0's l1: 0.0681364 [90] valid_0's l1: 0.0681347 [91] valid_0's l1: 0.0681329 [92] valid_0's l1: 0.0681311 [93] valid_0's l1: 0.0681291 [94] valid_0's l1: 0.0681269 [95] valid_0's l1: 0.0681248 [96] valid_0's l1: 0.0681233 [97] valid_0's l1: 0.0681207 [98] valid_0's l1: 0.0681182 [99] valid_0's l1: 0.0681167 [100] valid_0's l1: 0.068115 [101] valid_0's l1: 0.0681135 [102] valid_0's l1: 0.0681127 [103] valid_0's l1: 0.0681113 [104] valid_0's l1: 0.0681093 [105] valid_0's l1: 0.0681076 [106] valid_0's l1: 0.0681069 [107] valid_0's l1: 0.0681051 [108] valid_0's l1: 0.0681029 [109] valid_0's l1: 0.0681011 [110] valid_0's l1: 0.068099 [111] valid_0's l1: 0.0680975 [112] valid_0's l1: 0.0680951 [113] valid_0's l1: 0.068093 [114] valid_0's l1: 0.0680912 [115] valid_0's l1: 0.0680896 [116] valid_0's l1: 0.0680887 [117] valid_0's l1: 0.0680876 [118] valid_0's l1: 0.068086 [119] valid_0's l1: 0.068084 [120] valid_0's l1: 0.0680817 [121] valid_0's l1: 0.0680803 [122] valid_0's l1: 0.0680782 [123] valid_0's l1: 0.0680766 [124] valid_0's l1: 0.0680747 [125] valid_0's l1: 0.068073 [126] valid_0's l1: 0.0680719 [127] valid_0's l1: 0.0680702 [128] valid_0's l1: 0.0680692 [129] valid_0's l1: 0.0680678 [130] valid_0's l1: 0.0680666 [131] valid_0's l1: 0.0680654 [132] valid_0's l1: 0.0680643 [133] valid_0's l1: 0.0680626 [134] valid_0's l1: 0.0680606 [135] valid_0's l1: 0.0680589 [136] valid_0's l1: 0.0680576 [137] valid_0's l1: 0.0680556 [138] valid_0's 
l1: 0.0680547 [139] valid_0's l1: 0.0680536 [140] valid_0's l1: 0.0680521 [141] valid_0's l1: 0.0680502 [142] valid_0's l1: 0.068049 [143] valid_0's l1: 0.0680474 [144] valid_0's l1: 0.0680462 [145] valid_0's l1: 0.0680447 [146] valid_0's l1: 0.0680432 [147] valid_0's l1: 0.068042 [148] valid_0's l1: 0.0680408 [149] valid_0's l1: 0.0680398 [150] valid_0's l1: 0.0680388 [151] valid_0's l1: 0.068038 [152] valid_0's l1: 0.0680362 [153] valid_0's l1: 0.0680349 [154] valid_0's l1: 0.0680338 [155] valid_0's l1: 0.068032 [156] valid_0's l1: 0.0680314 [157] valid_0's l1: 0.0680303 [158] valid_0's l1: 0.0680284 [159] valid_0's l1: 0.0680273 [160] valid_0's l1: 0.0680254 [161] valid_0's l1: 0.068024 [162] valid_0's l1: 0.068023 [163] valid_0's l1: 0.0680221 [164] valid_0's l1: 0.0680214 [165] valid_0's l1: 0.0680201 [166] valid_0's l1: 0.068019 [167] valid_0's l1: 0.0680168 [168] valid_0's l1: 0.0680159 [169] valid_0's l1: 0.0680149 [170] valid_0's l1: 0.0680135 [171] valid_0's l1: 0.0680126 [172] valid_0's l1: 0.0680116 [173] valid_0's l1: 0.0680097 [174] valid_0's l1: 0.0680083 [175] valid_0's l1: 0.0680072 [176] valid_0's l1: 0.0680061 [177] valid_0's l1: 0.0680045 [178] valid_0's l1: 0.0680033 [179] valid_0's l1: 0.068002 [180] valid_0's l1: 0.0680012 [181] valid_0's l1: 0.0679996 [182] valid_0's l1: 0.0679988 [183] valid_0's l1: 0.0679978 [184] valid_0's l1: 0.0679963 [185] valid_0's l1: 0.0679949 [186] valid_0's l1: 0.0679938 [187] valid_0's l1: 0.0679921 [188] valid_0's l1: 0.0679913 [189] valid_0's l1: 0.0679897 [190] valid_0's l1: 0.0679884 [191] valid_0's l1: 0.0679876 [192] valid_0's l1: 0.0679867 [193] valid_0's l1: 0.067986 [194] valid_0's l1: 0.067985 [195] valid_0's l1: 0.0679837 [196] valid_0's l1: 0.0679831 [197] valid_0's l1: 0.067982 [198] valid_0's l1: 0.0679806 [199] valid_0's l1: 0.0679798 [200] valid_0's l1: 0.0679792 [201] valid_0's l1: 0.0679783 [202] valid_0's l1: 0.0679775 [203] valid_0's l1: 0.0679762 [204] valid_0's l1: 0.0679754 [205] valid_0's 
l1: 0.0679747 [206] valid_0's l1: 0.0679734 [207] valid_0's l1: 0.0679719 [208] valid_0's l1: 0.0679705 [209] valid_0's l1: 0.0679694 [210] valid_0's l1: 0.0679686 [211] valid_0's l1: 0.0679671 [212] valid_0's l1: 0.0679664 [213] valid_0's l1: 0.067965 [214] valid_0's l1: 0.0679636 [215] valid_0's l1: 0.0679631 [216] valid_0's l1: 0.0679615 [217] valid_0's l1: 0.0679604 [218] valid_0's l1: 0.0679593 [219] valid_0's l1: 0.0679587 [220] valid_0's l1: 0.0679575 [221] valid_0's l1: 0.0679571 [222] valid_0's l1: 0.0679562 [223] valid_0's l1: 0.0679546 [224] valid_0's l1: 0.0679541 [225] valid_0's l1: 0.0679529 [226] valid_0's l1: 0.0679517 [227] valid_0's l1: 0.0679506 [228] valid_0's l1: 0.0679495 [229] valid_0's l1: 0.0679487 [230] valid_0's l1: 0.0679476 [231] valid_0's l1: 0.0679466 [232] valid_0's l1: 0.0679455 [233] valid_0's l1: 0.0679443 [234] valid_0's l1: 0.0679429 [235] valid_0's l1: 0.0679426 [236] valid_0's l1: 0.0679417 [237] valid_0's l1: 0.067941 [238] valid_0's l1: 0.0679398 [239] valid_0's l1: 0.0679386 [240] valid_0's l1: 0.0679374 [241] valid_0's l1: 0.0679364 [242] valid_0's l1: 0.067936 [243] valid_0's l1: 0.0679352 [244] valid_0's l1: 0.0679339 [245] valid_0's l1: 0.0679328 [246] valid_0's l1: 0.0679323 [247] valid_0's l1: 0.0679315 [248] valid_0's l1: 0.0679302 [249] valid_0's l1: 0.0679295 [250] valid_0's l1: 0.0679293 [251] valid_0's l1: 0.0679285 [252] valid_0's l1: 0.0679276 [253] valid_0's l1: 0.0679266 [254] valid_0's l1: 0.0679256 [255] valid_0's l1: 0.0679245 [256] valid_0's l1: 0.0679231 [257] valid_0's l1: 0.0679217 [258] valid_0's l1: 0.0679206 [259] valid_0's l1: 0.06792 [260] valid_0's l1: 0.0679191 [261] valid_0's l1: 0.067918 [262] valid_0's l1: 0.0679175 [263] valid_0's l1: 0.0679169 [264] valid_0's l1: 0.0679156 [265] valid_0's l1: 0.0679149 [266] valid_0's l1: 0.067914 [267] valid_0's l1: 0.0679134 [268] valid_0's l1: 0.0679128 [269] valid_0's l1: 0.0679111 [270] valid_0's l1: 0.0679098 [271] valid_0's l1: 0.0679091 [272] 
valid_0's l1: 0.067908 [273] valid_0's l1: 0.0679075 [274] valid_0's l1: 0.067907 [275] valid_0's l1: 0.0679059 [276] valid_0's l1: 0.0679047 [277] valid_0's l1: 0.0679036 [278] valid_0's l1: 0.0679025 [279] valid_0's l1: 0.0679011 [280] valid_0's l1: 0.0679006 [281] valid_0's l1: 0.0679003 [282] valid_0's l1: 0.067899 [283] valid_0's l1: 0.0678986 [284] valid_0's l1: 0.0678976 [285] valid_0's l1: 0.0678962 [286] valid_0's l1: 0.0678956 [287] valid_0's l1: 0.0678946 [288] valid_0's l1: 0.0678936 [289] valid_0's l1: 0.0678932 [290] valid_0's l1: 0.067892 [291] valid_0's l1: 0.067891 [292] valid_0's l1: 0.0678904 [293] valid_0's l1: 0.0678898 [294] valid_0's l1: 0.0678894 [295] valid_0's l1: 0.0678881 [296] valid_0's l1: 0.0678875 [297] valid_0's l1: 0.0678871 [298] valid_0's l1: 0.0678866 [299] valid_0's l1: 0.0678863 [300] valid_0's l1: 0.0678855 [301] valid_0's l1: 0.0678842 [302] valid_0's l1: 0.0678832 [303] valid_0's l1: 0.0678825 [304] valid_0's l1: 0.0678819 [305] valid_0's l1: 0.0678813 [306] valid_0's l1: 0.0678804 [307] valid_0's l1: 0.0678798 [308] valid_0's l1: 0.0678791 [309] valid_0's l1: 0.0678784 [310] valid_0's l1: 0.0678776 [311] valid_0's l1: 0.0678769 [312] valid_0's l1: 0.0678758 [313] valid_0's l1: 0.0678749 [314] valid_0's l1: 0.0678739 [315] valid_0's l1: 0.0678729 [316] valid_0's l1: 0.0678719 [317] valid_0's l1: 0.0678716 [318] valid_0's l1: 0.067871 [319] valid_0's l1: 0.0678702 [320] valid_0's l1: 0.0678695 [321] valid_0's l1: 0.0678693 [322] valid_0's l1: 0.0678688 [323] valid_0's l1: 0.0678677 [324] valid_0's l1: 0.0678674 [325] valid_0's l1: 0.0678671 [326] valid_0's l1: 0.0678669 [327] valid_0's l1: 0.0678659 [328] valid_0's l1: 0.0678649 [329] valid_0's l1: 0.0678643 [330] valid_0's l1: 0.0678633 [331] valid_0's l1: 0.0678626 [332] valid_0's l1: 0.067862 [333] valid_0's l1: 0.0678616 [334] valid_0's l1: 0.0678616 [335] valid_0's l1: 0.0678609 [336] valid_0's l1: 0.0678601 [337] valid_0's l1: 0.0678599 [338] valid_0's l1: 0.0678591 
[339] valid_0's l1: 0.0678585 [340] valid_0's l1: 0.0678585 [341] valid_0's l1: 0.0678579 [342] valid_0's l1: 0.0678569 [343] valid_0's l1: 0.0678566 [344] valid_0's l1: 0.0678558 [345] valid_0's l1: 0.0678551 [346] valid_0's l1: 0.0678541 [347] valid_0's l1: 0.0678533 [348] valid_0's l1: 0.0678529 [349] valid_0's l1: 0.0678523 [350] valid_0's l1: 0.0678516 [351] valid_0's l1: 0.0678511 [352] valid_0's l1: 0.0678501 [353] valid_0's l1: 0.0678497 [354] valid_0's l1: 0.0678491 [355] valid_0's l1: 0.0678482 [356] valid_0's l1: 0.067848 [357] valid_0's l1: 0.0678474 [358] valid_0's l1: 0.0678465 [359] valid_0's l1: 0.0678462 [360] valid_0's l1: 0.0678458 [361] valid_0's l1: 0.0678452 [362] valid_0's l1: 0.0678444 [363] valid_0's l1: 0.067844 [364] valid_0's l1: 0.0678436 [365] valid_0's l1: 0.0678428 [366] valid_0's l1: 0.0678425 [367] valid_0's l1: 0.0678422 [368] valid_0's l1: 0.0678411 [369] valid_0's l1: 0.0678402 [370] valid_0's l1: 0.0678394 [371] valid_0's l1: 0.0678392 [372] valid_0's l1: 0.0678387 [373] valid_0's l1: 0.0678385 [374] valid_0's l1: 0.0678378 [375] valid_0's l1: 0.0678377 [376] valid_0's l1: 0.0678369 [377] valid_0's l1: 0.0678363 [378] valid_0's l1: 0.0678357 [379] valid_0's l1: 0.0678353 [380] valid_0's l1: 0.0678346 [381] valid_0's l1: 0.0678345 [382] valid_0's l1: 0.0678338 [383] valid_0's l1: 0.0678334 [384] valid_0's l1: 0.0678329 [385] valid_0's l1: 0.0678327 [386] valid_0's l1: 0.0678322 [387] valid_0's l1: 0.0678315 [388] valid_0's l1: 0.0678308 [389] valid_0's l1: 0.0678302 [390] valid_0's l1: 0.0678297 [391] valid_0's l1: 0.0678289 [392] valid_0's l1: 0.0678286 [393] valid_0's l1: 0.0678281 [394] valid_0's l1: 0.0678279 [395] valid_0's l1: 0.0678275 [396] valid_0's l1: 0.0678265 [397] valid_0's l1: 0.067826 [398] valid_0's l1: 0.0678253 [399] valid_0's l1: 0.0678248 [400] valid_0's l1: 0.0678241 [401] valid_0's l1: 0.0678235 [402] valid_0's l1: 0.0678233 [403] valid_0's l1: 0.0678231 [404] valid_0's l1: 0.0678229 [405] valid_0's l1: 
0.0678224 [406] valid_0's l1: 0.0678222 [407] valid_0's l1: 0.0678221 [408] valid_0's l1: 0.0678213 [409] valid_0's l1: 0.0678209 [410] valid_0's l1: 0.0678207 [411] valid_0's l1: 0.0678202 [412] valid_0's l1: 0.0678197 [413] valid_0's l1: 0.0678194 [414] valid_0's l1: 0.0678192 [415] valid_0's l1: 0.0678191 [416] valid_0's l1: 0.0678187 [417] valid_0's l1: 0.0678185 [418] valid_0's l1: 0.0678183 [419] valid_0's l1: 0.0678181 [420] valid_0's l1: 0.0678178 [421] valid_0's l1: 0.0678173 [422] valid_0's l1: 0.0678168 [423] valid_0's l1: 0.067816 [424] valid_0's l1: 0.0678155 [425] valid_0's l1: 0.0678149 [426] valid_0's l1: 0.0678144 [427] valid_0's l1: 0.0678138 [428] valid_0's l1: 0.0678132 [429] valid_0's l1: 0.0678125 [430] valid_0's l1: 0.0678119 [431] valid_0's l1: 0.0678115 [432] valid_0's l1: 0.0678112 [433] valid_0's l1: 0.0678111 [434] valid_0's l1: 0.0678109 [435] valid_0's l1: 0.0678107 [436] valid_0's l1: 0.0678104 [437] valid_0's l1: 0.0678094 [438] valid_0's l1: 0.0678092 [439] valid_0's l1: 0.067809 [440] valid_0's l1: 0.0678083 [441] valid_0's l1: 0.0678081 [442] valid_0's l1: 0.0678078 [443] valid_0's l1: 0.0678076 [444] valid_0's l1: 0.0678072 [445] valid_0's l1: 0.067807 [446] valid_0's l1: 0.0678068 [447] valid_0's l1: 0.0678065 [448] valid_0's l1: 0.0678063 [449] valid_0's l1: 0.0678058 [450] valid_0's l1: 0.0678057 [451] valid_0's l1: 0.0678052 [452] valid_0's l1: 0.067805 [453] valid_0's l1: 0.0678048 [454] valid_0's l1: 0.0678044 [455] valid_0's l1: 0.067804 [456] valid_0's l1: 0.067804 [457] valid_0's l1: 0.0678037 [458] valid_0's l1: 0.0678036 [459] valid_0's l1: 0.0678033 [460] valid_0's l1: 0.0678029 [461] valid_0's l1: 0.0678023 [462] valid_0's l1: 0.0678023 [463] valid_0's l1: 0.067802 [464] valid_0's l1: 0.0678019 [465] valid_0's l1: 0.0678017 [466] valid_0's l1: 0.0678017 [467] valid_0's l1: 0.0678012 [468] valid_0's l1: 0.0678012 [469] valid_0's l1: 0.0678011 [470] valid_0's l1: 0.0678009 [471] valid_0's l1: 0.0678004 [472] valid_0's 
l1: 0.0678 [473] valid_0's l1: 0.0677998 [474] valid_0's l1: 0.0677996 [475] valid_0's l1: 0.0677992 [476] valid_0's l1: 0.0677992 [477] valid_0's l1: 0.067799 [478] valid_0's l1: 0.0677988 [479] valid_0's l1: 0.0677984 [480] valid_0's l1: 0.0677984 [481] valid_0's l1: 0.0677982 [482] valid_0's l1: 0.067798 [483] valid_0's l1: 0.0677978 [484] valid_0's l1: 0.0677977 [485] valid_0's l1: 0.0677977 [486] valid_0's l1: 0.0677973 [487] valid_0's l1: 0.0677966 [488] valid_0's l1: 0.0677965 [489] valid_0's l1: 0.0677965 [490] valid_0's l1: 0.0677964 [491] valid_0's l1: 0.0677958 [492] valid_0's l1: 0.0677958 [493] valid_0's l1: 0.0677956 [494] valid_0's l1: 0.0677955 [495] valid_0's l1: 0.0677953 [496] valid_0's l1: 0.0677951 [497] valid_0's l1: 0.0677947 [498] valid_0's l1: 0.0677941 [499] valid_0's l1: 0.0677939 [500] valid_0's l1: 0.0677933
from sklearn.metrics import mean_absolute_error
def get_score(x, y, clf, ii):
    """Return the mean absolute error of clf's predictions on x at iteration ii.

    Computed directly with numpy (mean of |y - pred|), which for these 1-D
    targets is identical to sklearn.metrics.mean_absolute_error but avoids
    the extra sklearn dependency (numpy is already imported in this script).

    x  : feature matrix accepted by clf.predict
    y  : 1-D array of true targets
    clf: trained LightGBM booster (anything with predict(x, num_iteration=...))
    ii : number of boosting iterations to use for prediction
    """
    pred = clf.predict(x, num_iteration=ii)
    return np.mean(np.abs(np.asarray(y) - pred))
# Compare the two boosters on the held-out tail, scanning iteration counts
# from 70% of num_round up to and including num_round in steps of 5.
# Idiom fix: the original manual while-loop with ii += 5 is replaced by the
# equivalent range() for-loop.
lab = []
val_gbdt = []
val_part = []
for ii in range(int(0.7 * num_round), num_round + 1, 5):
    lab.append(ii)
    val_gbdt.append(get_score(xv, yv, clf_gbdt, ii))
    val_part.append(get_score(xv, yv, clf_part, ii))
# Report the best (lowest) MAE each booster reached over the scanned range.
print(f'GBDT: {np.array(val_gbdt).min()}')
print(f'PART: {np.array(val_part).min()}')
GBDT: 0.06612165068883384 PART: 0.06612067704950389