import pandas as pd

# Load the train/test splits (tab-separated auto-mpg-style data).
train_df = pd.read_table('./dataset/train.tsv')
test_df = pd.read_table('./dataset/test.tsv')

# One-hot encode every categorical column (e.g. 'horsepower', 'car name').
# BUG FIX: calling get_dummies() on train and test independently produces
# different column sets (each split contains different category values), so
# a model fit on the train dummies cannot correctly score the test dummies.
# Align test to the train feature columns (everything except the 'mpg'
# target), filling categories absent from test with 0.
train_df = pd.get_dummies(train_df)
test_df = pd.get_dummies(test_df)
test_df = test_df.reindex(columns=train_df.columns.drop('mpg'), fill_value=0)

import lightgbm as lgb
from sklearn.model_selection import train_test_split

# Hold out a third of the training data for validation / early stopping.
y = train_df['mpg']
X_train, X_test, y_train, y_test = train_test_split(
    train_df, y, test_size=0.33, random_state=0)
# The target column rode along in the feature frames; drop it.
X_train = X_train.drop('mpg', axis=1)
X_test = X_test.drop('mpg', axis=1)
X_test  # notebook-style display of the validation features
id | cylinders | displacement | weight | acceleration | model year | origin | horsepower_100.0 | horsepower_105.0 | horsepower_108.0 | ... | car name_volkswagen model 111 | car name_volkswagen rabbit custom diesel | car name_volkswagen rabbit l | car name_volkswagen scirocco | car name_volkswagen type 3 | car name_volvo 144ea | car name_volvo 145e (sw) | car name_vw dasher (diesel) | car name_vw rabbit | car name_vw rabbit custom | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
18 | 40 | 4 | 141.0 | 3230.0 | 20.4 | 81 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
169 | 338 | 4 | 98.0 | 2135.0 | 16.6 | 78 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
106 | 211 | 8 | 318.0 | 4140.0 | 13.7 | 77 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
92 | 188 | 4 | 122.0 | 2300.0 | 15.5 | 77 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
176 | 351 | 4 | 134.0 | 2711.0 | 15.5 | 80 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
183 | 366 | 6 | 231.0 | 3445.0 | 13.4 | 78 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 | 16 | 4 | 151.0 | 2735.0 | 18.0 | 82 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
139 | 276 | 8 | 262.0 | 3221.0 | 13.5 | 75 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
12 | 31 | 4 | 97.0 | 2130.0 | 14.5 | 70 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
160 | 318 | 4 | 135.0 | 2295.0 | 11.6 | 82 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
61 | 123 | 4 | 121.0 | 2600.0 | 12.8 | 77 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
124 | 252 | 4 | 97.0 | 2254.0 | 23.5 | 72 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 |
164 | 329 | 6 | 225.0 | 3620.0 | 18.7 | 78 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
145 | 288 | 4 | 85.0 | 1835.0 | 17.3 | 80 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
80 | 166 | 4 | 97.0 | 1834.0 | 19.0 | 71 | 2 | 0 | 0 | 0 | ... | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
7 | 19 | 8 | 350.0 | 3664.0 | 11.0 | 73 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
33 | 71 | 4 | 79.0 | 1755.0 | 16.9 | 81 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
129 | 262 | 6 | 173.0 | 2725.0 | 12.6 | 81 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
37 | 79 | 4 | 97.0 | 2145.0 | 18.0 | 80 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
74 | 152 | 6 | 181.0 | 2945.0 | 16.4 | 82 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
159 | 316 | 6 | 250.0 | 3139.0 | 14.5 | 71 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
144 | 285 | 4 | 97.0 | 1835.0 | 20.5 | 70 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
45 | 91 | 4 | 140.0 | 2639.0 | 17.0 | 75 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
158 | 315 | 4 | 83.0 | 2003.0 | 19.0 | 74 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
60 | 121 | 4 | 90.0 | 2223.0 | 16.5 | 75 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
122 | 249 | 4 | 105.0 | 2190.0 | 14.2 | 81 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
178 | 354 | 8 | 318.0 | 3735.0 | 13.2 | 78 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
184 | 367 | 8 | 400.0 | 4997.0 | 14.0 | 73 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
97 | 197 | 4 | 97.0 | 2190.0 | 14.1 | 77 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
44 | 88 | 4 | 122.0 | 2395.0 | 16.0 | 72 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
135 | 270 | 5 | 121.0 | 2950.0 | 19.9 | 80 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 13 | 4 | 89.0 | 2050.0 | 17.3 | 81 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
137 | 273 | 4 | 156.0 | 2620.0 | 14.4 | 81 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
193 | 383 | 6 | 232.0 | 2634.0 | 13.0 | 71 | 1 | 1 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
153 | 306 | 8 | 350.0 | 4502.0 | 13.5 | 72 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
66 | 130 | 4 | 140.0 | 2408.0 | 19.0 | 71 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
26 | 53 | 8 | 400.0 | 5140.0 | 12.0 | 71 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
131 | 265 | 4 | 79.0 | 2000.0 | 16.0 | 74 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
146 | 290 | 4 | 151.0 | 2678.0 | 16.5 | 80 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
63 | 127 | 8 | 351.0 | 4657.0 | 13.5 | 75 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
8 | 23 | 4 | 121.0 | 2234.0 | 12.5 | 70 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
75 | 157 | 8 | 429.0 | 4341.0 | 10.0 | 70 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
98 | 201 | 4 | 105.0 | 2150.0 | 14.9 | 79 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
166 | 335 | 4 | 140.0 | 2565.0 | 13.6 | 76 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
71 | 148 | 4 | 116.0 | 2123.0 | 14.0 | 71 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
196 | 387 | 8 | 304.0 | 3433.0 | 12.0 | 70 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
86 | 180 | 4 | 79.0 | 1963.0 | 15.5 | 74 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
96 | 195 | 4 | 120.0 | 2489.0 | 15.0 | 74 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
149 | 297 | 6 | 225.0 | 3233.0 | 15.4 | 76 | 1 | 1 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
24 | 48 | 4 | 151.0 | 3035.0 | 20.5 | 82 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
30 | 67 | 4 | 98.0 | 2045.0 | 18.5 | 77 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
113 | 234 | 4 | 100.0 | 2320.0 | 15.8 | 81 | 2 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
40 | 83 | 4 | 90.0 | 1985.0 | 21.5 | 78 | 2 | 0 | 0 | 0 | ... | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
56 | 115 | 8 | 360.0 | 3821.0 | 11.0 | 73 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
130 | 264 | 8 | 302.0 | 3205.0 | 11.2 | 78 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
182 | 364 | 8 | 305.0 | 3880.0 | 12.5 | 77 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
180 | 359 | 4 | 72.0 | 1613.0 | 18.0 | 71 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
19 | 41 | 6 | 168.0 | 2910.0 | 11.4 | 80 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
152 | 305 | 4 | 140.0 | 2755.0 | 15.8 | 77 | 1 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
136 | 272 | 6 | 156.0 | 2807.0 | 13.5 | 73 | 3 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
66 rows × 245 columns
import numpy as np

# Wrap the splits in LightGBM datasets; the eval set references the train
# set so both share the same feature-binning.
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

# LightGBM parameters: GBDT regression minimizing L2 (MSE).
params = {
    'task': 'train',
    'boosting_type': 'gbdt',
    'objective': 'regression',
    'metric': {'l2'},
    'num_leaves': 31,
    'learning_rate': 0.1,
    'feature_fraction': 0.9,   # column subsampling per tree
    'bagging_fraction': 0.8,   # row subsampling...
    'bagging_freq': 5,         # ...re-drawn every 5 iterations
    'verbose': 0,
}

# Train up to 100 rounds, stopping early once the validation L2 fails to
# improve for 10 consecutive rounds.
# NOTE(review): `early_stopping_rounds=` was removed from lgb.train() in
# LightGBM 4.x; on newer versions pass callbacks=[lgb.early_stopping(10)].
gbm = lgb.train(params,
                lgb_train,
                num_boost_round=100,
                valid_sets=lgb_eval,
                early_stopping_rounds=10)

# BUG FIX: predict on a frame with exactly the columns (and order) the model
# was fit on. The raw test_df carries its own dummy columns, which LightGBM
# would otherwise misalign against the training features.
test_features = test_df.reindex(columns=X_train.columns, fill_value=0)
y_pred = gbm.predict(test_features, num_iteration=gbm.best_iteration)
[1] valid_0's l2: 49.1661 Training until validation scores don't improve for 10 rounds. [2] valid_0's l2: 43.31 [3] valid_0's l2: 38.9223 [4] valid_0's l2: 35.4825 [5] valid_0's l2: 31.4631 [6] valid_0's l2: 28.1653 [7] valid_0's l2: 25.473 [8] valid_0's l2: 23.0791 [9] valid_0's l2: 21.3668 [10] valid_0's l2: 19.7726 [11] valid_0's l2: 18.1561 [12] valid_0's l2: 16.8996 [13] valid_0's l2: 15.913 [14] valid_0's l2: 15.2111 [15] valid_0's l2: 14.6074 [16] valid_0's l2: 13.7934 [17] valid_0's l2: 13.2737 [18] valid_0's l2: 12.8726 [19] valid_0's l2: 12.6126 [20] valid_0's l2: 12.1633 [21] valid_0's l2: 11.8965 [22] valid_0's l2: 11.6803 [23] valid_0's l2: 11.5124 [24] valid_0's l2: 11.3882 [25] valid_0's l2: 11.1853 [26] valid_0's l2: 10.9705 [27] valid_0's l2: 10.8323 [28] valid_0's l2: 10.6909 [29] valid_0's l2: 10.6382 [30] valid_0's l2: 10.5787 [31] valid_0's l2: 10.5342 [32] valid_0's l2: 10.4999 [33] valid_0's l2: 10.5216 [34] valid_0's l2: 10.5298 [35] valid_0's l2: 10.5578 [36] valid_0's l2: 10.4628 [37] valid_0's l2: 10.4111 [38] valid_0's l2: 10.3446 [39] valid_0's l2: 10.3691 [40] valid_0's l2: 10.3154 [41] valid_0's l2: 10.3383 [42] valid_0's l2: 10.3248 [43] valid_0's l2: 10.3402 [44] valid_0's l2: 10.3548 [45] valid_0's l2: 10.334 [46] valid_0's l2: 10.2862 [47] valid_0's l2: 10.3016 [48] valid_0's l2: 10.3147 [49] valid_0's l2: 10.3306 [50] valid_0's l2: 10.3417 [51] valid_0's l2: 10.262 [52] valid_0's l2: 10.197 [53] valid_0's l2: 10.1877 [54] valid_0's l2: 10.203 [55] valid_0's l2: 10.0974 [56] valid_0's l2: 10.1084 [57] valid_0's l2: 10.1324 [58] valid_0's l2: 10.1469 [59] valid_0's l2: 10.0992 [60] valid_0's l2: 10.1153 [61] valid_0's l2: 10.1255 [62] valid_0's l2: 10.1098 [63] valid_0's l2: 10.1145 [64] valid_0's l2: 10.1067 [65] valid_0's l2: 10.1194 Early stopping, best iteration is: [55] valid_0's l2: 10.0974
y_pred  # notebook-style display of the predicted mpg values for the test set
array([26.52835238, 18.49803738, 32.67181357, 16.31977399, 28.34038704, 24.25271276, 15.23983205, 19.04545687, 22.31722562, 20.63312642, 13.35279956, 29.17365419, 13.35279956, 35.20868991, 34.11221742, 15.44318243, 23.70059147, 22.5022716 , 12.80555576, 29.27162409, 26.73482457, 36.80466453, 16.04857589, 13.35279956, 18.11987521, 13.35279956, 13.96480581, 20.07013574, 31.43048483, 29.6024268 , 31.71756636, 17.43539101, 36.5433746 , 13.53353208, 17.44598433, 33.16601043, 31.64783442, 18.94879111, 26.40155324, 36.34432887, 13.53353208, 29.46386951, 21.78137328, 28.80689612, 20.10981617, 36.80466453, 36.39168504, 18.94779047, 22.24534944, 27.37623472, 22.38475853, 27.36331438, 32.06239949, 17.58270851, 27.35080543, 15.23983205, 15.99446218, 30.49175808, 13.35279956, 15.20665478, 15.81372966, 28.39681341, 31.9915289 , 15.81372966, 23.75903806, 13.53353208, 36.28536236, 26.96715497, 18.29865123, 13.55614994, 18.92494464, 34.39809039, 35.29991474, 13.35279956, 32.47827898, 29.11036212, 17.19473263, 13.53353208, 13.53353208, 20.20590641, 30.20598152, 32.0724365 , 23.9391743 , 23.87371027, 22.10088915, 36.28536236, 18.88601326, 18.1910494 , 13.35279956, 13.35279956, 23.37453678, 32.9725995 , 15.27583175, 21.09623432, 16.65891689, 15.23342198, 35.90821676, 27.95908648, 13.49112231, 23.98837612, 28.80061185, 28.06170948, 18.6303405 , 16.55563817, 25.06472211, 31.29913793, 21.53847536, 15.27604631, 16.74117753, 17.03072793, 13.38901382, 24.47944966, 27.78378752, 28.05540311, 13.38901382, 28.59609612, 13.56974634, 15.31204601, 20.19033798, 19.5803554 , 37.04173672, 15.55386905, 18.49447448, 20.46574499, 18.49447448, 19.05288372, 20.27938902, 19.4048599 , 15.11914287, 36.58140106, 31.57576059, 35.64954444, 21.46214322, 19.53110383, 15.7469124 , 24.42914623, 21.42287602, 24.89782518, 35.83402956, 27.91547282, 29.31669131, 13.74828809, 32.14135493, 16.08235759, 14.7705975 , 29.21818079, 20.46838331, 29.86918275, 26.86171747, 18.57627399, 21.32438674, 16.76118959, 24.24910896, 
23.69933402, 25.50018735, 29.22560549, 15.30959382, 29.6721998 , 15.30959382, 13.74828809, 36.86831198, 19.95223354, 29.6721998 , 25.8151922 , 13.74828809, 18.42143006, 18.60845019, 22.54429044, 19.97064504, 13.62142749, 13.62142749, 37.40735977, 19.73677715, 16.08235759, 23.09598025, 25.58174245, 18.0334027 , 28.51807315, 33.08539779, 13.74828809, 18.48535562, 20.54148264, 29.6721998 , 13.74828809, 23.53349758, 14.7705975 , 36.56840796, 24.70485103, 26.23972761, 21.73480769, 37.40735977, 25.78912758, 16.73884635, 25.2170913 , 24.69412433, 28.0033536 , 30.27401997, 15.58741655, 30.94077257])
# Assemble the submission in (id, mpg) order and write it headerless.
submission = pd.DataFrame({'id': test_df['id'],
                           'mpg': y_pred})
# IDIOM FIX: the original passed header=0, relying on 0 being falsy; use the
# explicit boolean — the emitted file is byte-identical (no header row).
submission.to_csv("submission.csv", index=False, header=False)