import os
import numpy as np
import pandas as pd
# Build the path to the game-results CSV under ./data and preview the raw file.
home_folder = "."
data_folder = os.path.join(home_folder, "data")
data_filename = os.path.join(
    data_folder,
    "leagues_NBA_2014_games_games.csv",
)
results = pd.read_csv(data_filename)
# First five rows, as parsed with pandas' default header handling.
results.head(5)
Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | |
---|---|---|---|---|---|---|---|---|
0 | Tue Oct 29 2013 | Box Score | Orlando Magic | 87 | Indiana Pacers | 97 | NaN | NaN |
1 | Tue Oct 29 2013 | Box Score | Los Angeles Clippers | 103 | Los Angeles Lakers | 116 | NaN | NaN |
2 | Tue Oct 29 2013 | Box Score | Chicago Bulls | 95 | Miami Heat | 107 | NaN | NaN |
3 | Wed Oct 30 2013 | Box Score | Brooklyn Nets | 94 | Cleveland Cavaliers | 98 | NaN | NaN |
4 | Wed Oct 30 2013 | Box Score | Atlanta Hawks | 109 | Dallas Mavericks | 118 | NaN | NaN |
# Re-read the games with clean column names and a parsed Date column.
# header=0 consumes the file's own header line and `names` replaces it.
# Do NOT use skiprows=[0] here: the first line is the header, and skipping it
# promotes the first game to the header row so that game is silently dropped.
# Blank lines are already skipped by read_csv (skip_blank_lines=True).
results = pd.read_csv(
    data_filename,
    header=0,
    names=["Date", "Score Type", "Visitor Team", "VisitorPts",
           "Home Team", "HomePts", "OT?", "Notes"],
    parse_dates=["Date"],
)
results.iloc[:5]
Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | |
---|---|---|---|---|---|---|---|---|
0 | Tue Oct 29 2013 | Box Score | Los Angeles Clippers | 103 | Los Angeles Lakers | 116 | NaN | NaN |
1 | Tue Oct 29 2013 | Box Score | Chicago Bulls | 95 | Miami Heat | 107 | NaN | NaN |
2 | Wed Oct 30 2013 | Box Score | Brooklyn Nets | 94 | Cleveland Cavaliers | 98 | NaN | NaN |
3 | Wed Oct 30 2013 | Box Score | Atlanta Hawks | 109 | Dallas Mavericks | 118 | NaN | NaN |
4 | Wed Oct 30 2013 | Box Score | Washington Wizards | 102 | Detroit Pistons | 113 | NaN | NaN |
# Class label: True when the home side scored more points than the visitors.
home_win = results["HomePts"] > results["VisitorPts"]
results["HomeWin"] = home_win
# Our "class values" as a NumPy array for scikit-learn.
y_true = home_win.values
results.head(5)
Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | HomeWin | |
---|---|---|---|---|---|---|---|---|---|
0 | Tue Oct 29 2013 | Box Score | Los Angeles Clippers | 103 | Los Angeles Lakers | 116 | NaN | NaN | True |
1 | Tue Oct 29 2013 | Box Score | Chicago Bulls | 95 | Miami Heat | 107 | NaN | NaN | True |
2 | Wed Oct 30 2013 | Box Score | Brooklyn Nets | 94 | Cleveland Cavaliers | 98 | NaN | NaN | True |
3 | Wed Oct 30 2013 | Box Score | Atlanta Hawks | 109 | Dallas Mavericks | 118 | NaN | NaN | True |
4 | Wed Oct 30 2013 | Box Score | Washington Wizards | 102 | Detroit Pistons | 113 | NaN | NaN | True |
# Baseline: percentage of games won by the home team (the accuracy of always
# predicting a home win).
home_wins = results["HomeWin"]
home_win_pct = 100 * home_wins.sum() / home_wins.count()
print("Home Win 百分比: {0:.1f}%".format(home_win_pct))
# Two placeholder feature columns, initialised to False and filled in below.
for column in ("HomeLastWin", "VisitorLastWin"):
    results[column] = False
results.iloc[:5]
Home Win 百分比: 58.0%
Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | HomeWin | HomeLastWin | VisitorLastWin | |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | Tue Oct 29 2013 | Box Score | Los Angeles Clippers | 103 | Los Angeles Lakers | 116 | NaN | NaN | True | False | False |
1 | Tue Oct 29 2013 | Box Score | Chicago Bulls | 95 | Miami Heat | 107 | NaN | NaN | True | False | False |
2 | Wed Oct 30 2013 | Box Score | Brooklyn Nets | 94 | Cleveland Cavaliers | 98 | NaN | NaN | True | False | False |
3 | Wed Oct 30 2013 | Box Score | Atlanta Hawks | 109 | Dallas Mavericks | 118 | NaN | NaN | True | False | False |
4 | Wed Oct 30 2013 | Box Score | Washington Wizards | 102 | Detroit Pistons | 113 | NaN | NaN | True | False | False |
# Now compute the actual values for these:
# did the home and visitor teams win their previous game?
from collections import defaultdict

# team name -> whether that team won its most recent game so far.
# bool default keeps the feature columns boolean (int 0 would upcast them).
won_last = defaultdict(bool)
home_last_win = []
visitor_last_win = []
# Rows are walked top to bottom, so this assumes the CSV is in chronological
# order (it appears to be, judging by the Date column).
for home_team, visitor_team, home_win in zip(results["Home Team"],
                                             results["Visitor Team"],
                                             results["HomeWin"]):
    # Record each team's state BEFORE this game, so the feature never
    # leaks the result it is meant to predict.
    home_last_win.append(won_last[home_team])
    visitor_last_win.append(won_last[visitor_team])
    # Then update with the outcome of the current game.
    won_last[home_team] = bool(home_win)
    won_last[visitor_team] = not home_win
# Whole-column assignment instead of per-row write-back: much faster and
# does not change the columns' dtype.
results["HomeLastWin"] = home_last_win
results["VisitorLastWin"] = visitor_last_win
results.iloc[20:25]
Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | HomeWin | HomeLastWin | VisitorLastWin | |
---|---|---|---|---|---|---|---|---|---|---|---|
20 | Fri Nov 1 2013 | Box Score | Miami Heat | 100 | Brooklyn Nets | 101 | NaN | NaN | True | False | False |
21 | Fri Nov 1 2013 | Box Score | Cleveland Cavaliers | 84 | Charlotte Bobcats | 90 | NaN | NaN | True | False | True |
22 | Fri Nov 1 2013 | Box Score | Portland Trail Blazers | 113 | Denver Nuggets | 98 | NaN | NaN | False | False | False |
23 | Fri Nov 1 2013 | Box Score | Dallas Mavericks | 105 | Houston Rockets | 113 | NaN | NaN | True | True | True |
24 | Fri Nov 1 2013 | Box Score | San Antonio Spurs | 91 | Los Angeles Lakers | 85 | NaN | NaN | False | False | True |
from sklearn.tree import DecisionTreeClassifier
# sklearn.cross_validation was deprecated in 0.18 and removed in 0.20;
# cross_val_score now lives in sklearn.model_selection.
from sklearn.model_selection import cross_val_score

# Create a dataset with just the necessary information
X_previouswins = results[["HomeLastWin", "VisitorLastWin"]].values
clf = DecisionTreeClassifier(random_state=14)
scores = cross_val_score(clf, X_previouswins, y_true, scoring='accuracy')
print("Using just the last result from the home and visitor teams")
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Using just the last result from the home and visitor teams Accuracy: 57.7%
/home/dlinking-lxy/more-space/pyworks/venv/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20. "This module will be removed in 0.20.", DeprecationWarning)
# What about win streaks?
from collections import defaultdict

# team name -> number of consecutive wins so far (0 before its first game).
win_streak = defaultdict(int)
home_streaks = []
visitor_streaks = []
# Walks rows top to bottom; assumes chronological order, as in the CSV.
for home_team, visitor_team, home_win in zip(results["Home Team"],
                                             results["Visitor Team"],
                                             results["HomeWin"]):
    # Streaks carried INTO this game (no leakage of the current result).
    home_streaks.append(win_streak[home_team])
    visitor_streaks.append(win_streak[visitor_team])
    # Extend the winner's streak and reset the loser's.
    if home_win:
        win_streak[home_team] += 1
        win_streak[visitor_team] = 0
    else:
        win_streak[home_team] = 0
        win_streak[visitor_team] += 1
results["HomeWinStreak"] = home_streaks
results["VisitorWinStreak"] = visitor_streaks

clf = DecisionTreeClassifier(random_state=14)
X_winstreak = results[["HomeLastWin", "VisitorLastWin", "HomeWinStreak", "VisitorWinStreak"]].values
scores = cross_val_score(clf, X_winstreak, y_true, scoring='accuracy')
print("Using the last result and win streaks of the home and visitor teams")
print("Accuracy: {0:.1f}%".format(np.mean(scores) * 100))
Using whether the home team is ranked higher Accuracy: 56.2%
# Which team is stronger? Look both sides up on the previous season's ladder
# (the 2013 expanded standings file).
ladder_filename = os.path.join(
    data_folder, "leagues_NBA_2013_standings_expanded-standings.csv"
)
ladder = pd.read_csv(ladder_filename)
ladder
Rk | Team | Overall | Home | Road | E | W | A | C | SE | ... | Post | ≤3 | ≥10 | Oct | Nov | Dec | Jan | Feb | Mar | Apr | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | Miami Heat | 66-16 | 37-4 | 29-12 | 41-11 | 25-5 | 14-4 | 12-6 | 15-1 | ... | 30-2 | 9-3 | 39-8 | 1-0 | 10-3 | 10-5 | 8-5 | 12-1 | 17-1 | 8-1 |
1 | 2 | Oklahoma City Thunder | 60-22 | 34-7 | 26-15 | 21-9 | 39-13 | 7-3 | 8-2 | 6-4 | ... | 21-8 | 3-6 | 44-6 | NaN | 13-4 | 11-2 | 11-5 | 7-4 | 12-5 | 6-2 |
2 | 3 | San Antonio Spurs | 58-24 | 35-6 | 23-18 | 25-5 | 33-19 | 8-2 | 9-1 | 8-2 | ... | 16-12 | 9-5 | 31-10 | 1-0 | 12-4 | 12-4 | 12-3 | 8-3 | 10-4 | 3-6 |
3 | 4 | Denver Nuggets | 57-25 | 38-3 | 19-22 | 19-11 | 38-14 | 5-5 | 10-0 | 4-6 | ... | 24-4 | 11-7 | 28-8 | 0-1 | 8-8 | 9-6 | 12-3 | 8-4 | 13-2 | 7-1 |
4 | 5 | Los Angeles Clippers | 56-26 | 32-9 | 24-17 | 21-9 | 35-17 | 7-3 | 8-2 | 6-4 | ... | 17-9 | 3-5 | 38-12 | 1-0 | 8-6 | 16-0 | 9-7 | 8-5 | 7-7 | 7-1 |
5 | 6 | Memphis Grizzlies | 56-26 | 32-9 | 24-17 | 22-8 | 34-18 | 8-2 | 8-2 | 6-4 | ... | 23-8 | 6-4 | 28-9 | 0-1 | 12-1 | 7-7 | 10-7 | 9-2 | 11-6 | 7-2 |
6 | 7 | New York Knicks | 54-28 | 31-10 | 23-18 | 37-15 | 17-13 | 10-6 | 12-6 | 15-3 | ... | 22-10 | 7-5 | 31-12 | NaN | 11-4 | 10-5 | 7-6 | 6-5 | 12-6 | 8-2 |
7 | 8 | Brooklyn Nets | 49-33 | 26-15 | 23-18 | 36-16 | 13-17 | 11-5 | 13-5 | 12-6 | ... | 18-11 | 9-4 | 23-17 | NaN | 11-4 | 5-11 | 11-4 | 7-5 | 8-7 | 7-2 |
8 | 9 | Indiana Pacers | 49-32 | 30-11 | 19-21 | 31-20 | 18-12 | 6-11 | 13-3 | 12-6 | ... | 17-11 | 4-9 | 27-14 | 1-0 | 7-8 | 10-5 | 9-6 | 9-3 | 11-5 | 2-5 |
9 | 10 | Golden State Warriors | 47-35 | 28-13 | 19-22 | 19-11 | 28-24 | 7-3 | 5-5 | 7-3 | ... | 17-13 | 5-3 | 20-18 | 1-0 | 8-6 | 12-4 | 8-7 | 4-8 | 9-7 | 5-3 |
10 | 11 | Chicago Bulls | 45-37 | 24-17 | 21-20 | 34-18 | 11-19 | 13-5 | 9-7 | 12-6 | ... | 15-15 | 11-7 | 16-16 | 1-0 | 6-7 | 9-6 | 12-4 | 5-8 | 7-7 | 5-5 |
11 | 12 | Houston Rockets | 45-37 | 29-12 | 16-25 | 21-9 | 24-28 | 7-3 | 7-3 | 7-3 | ... | 16-11 | 5-5 | 26-13 | 1-0 | 6-8 | 10-6 | 8-9 | 6-5 | 9-5 | 5-4 |
12 | 13 | Los Angeles Lakers | 45-37 | 29-12 | 16-25 | 17-13 | 28-24 | 6-4 | 6-4 | 5-5 | ... | 20-8 | 8-5 | 18-17 | 0-2 | 8-6 | 7-7 | 5-11 | 9-4 | 9-6 | 7-1 |
13 | 14 | Atlanta Hawks | 44-38 | 25-16 | 19-22 | 29-23 | 15-15 | 7-11 | 11-7 | 11-5 | ... | 15-16 | 5-5 | 19-20 | NaN | 9-5 | 10-5 | 7-9 | 7-4 | 8-10 | 3-5 |
14 | 15 | Utah Jazz | 43-39 | 30-11 | 13-28 | 17-13 | 26-26 | 5-5 | 5-5 | 7-3 | ... | 13-15 | 5-7 | 19-21 | 1-0 | 8-8 | 6-9 | 10-4 | 6-6 | 7-9 | 5-3 |
15 | 16 | Boston Celtics | 41-40 | 27-13 | 14-27 | 27-24 | 14-16 | 7-9 | 8-9 | 12-6 | ... | 13-16 | 8-7 | 18-23 | 0-1 | 9-6 | 5-9 | 8-7 | 8-4 | 8-8 | 3-5 |
16 | 17 | Dallas Mavericks | 41-41 | 24-17 | 17-24 | 17-13 | 24-28 | 5-5 | 6-4 | 6-4 | ... | 18-12 | 5-8 | 17-19 | 1-1 | 6-8 | 5-10 | 7-8 | 6-5 | 11-5 | 5-4 |
17 | 18 | Milwaukee Bucks | 38-44 | 21-20 | 17-24 | 24-28 | 14-16 | 11-7 | 7-9 | 6-12 | ... | 12-19 | 7-5 | 13-25 | NaN | 7-7 | 9-6 | 8-7 | 4-8 | 7-9 | 3-7 |
18 | 19 | Philadelphia 76ers | 34-48 | 23-18 | 11-30 | 22-30 | 12-18 | 7-9 | 7-11 | 8-10 | ... | 12-19 | 4-5 | 13-24 | 1-0 | 9-6 | 4-11 | 5-9 | 3-8 | 8-9 | 4-5 |
19 | 20 | Toronto Raptors | 34-48 | 21-20 | 13-28 | 22-30 | 12-18 | 5-11 | 8-10 | 9-9 | ... | 13-16 | 8-8 | 16-22 | 0-1 | 4-12 | 7-7 | 5-10 | 7-5 | 4-11 | 7-2 |
20 | 21 | Portland Trail Blazers | 33-49 | 22-19 | 11-30 | 15-15 | 18-34 | 5-5 | 5-5 | 5-5 | ... | 8-21 | 9-6 | 13-24 | 1-0 | 5-10 | 9-4 | 8-8 | 3-9 | 7-9 | 0-9 |
21 | 22 | Minnesota Timberwolves | 31-51 | 20-21 | 11-30 | 14-16 | 17-35 | 4-6 | 7-3 | 3-7 | ... | 12-20 | 3-10 | 15-26 | NaN | 7-8 | 7-5 | 3-12 | 3-10 | 6-11 | 5-5 |
22 | 23 | Detroit Pistons | 29-53 | 18-23 | 11-30 | 25-27 | 4-26 | 6-12 | 8-8 | 11-7 | ... | 8-20 | 6-10 | 15-29 | 0-1 | 5-11 | 6-10 | 6-7 | 6-8 | 1-13 | 5-3 |
23 | 24 | Washington Wizards | 29-53 | 22-19 | 7-34 | 15-37 | 14-16 | 5-13 | 5-13 | 5-11 | ... | 14-17 | 6-9 | 13-17 | 0-1 | 1-12 | 3-11 | 7-9 | 7-5 | 9-8 | 2-7 |
24 | 25 | Sacramento Kings | 28-54 | 20-21 | 8-33 | 14-16 | 14-38 | 4-6 | 4-6 | 6-4 | ... | 9-19 | 7-3 | 12-31 | 0-1 | 4-10 | 7-8 | 6-11 | 3-9 | 7-8 | 1-7 |
25 | 26 | New Orleans Hornets | 27-55 | 16-25 | 11-30 | 12-18 | 15-37 | 3-7 | 5-5 | 4-6 | ... | 8-21 | 7-5 | 10-27 | 0-1 | 4-9 | 3-13 | 8-8 | 5-8 | 6-9 | 1-7 |
26 | 27 | Phoenix Suns | 25-57 | 17-24 | 8-33 | 8-22 | 17-35 | 1-9 | 4-6 | 3-7 | ... | 8-21 | 6-8 | 10-31 | 0-1 | 7-9 | 4-11 | 5-9 | 4-9 | 3-12 | 2-6 |
27 | 28 | Cleveland Cavaliers | 24-58 | 14-27 | 10-31 | 18-34 | 6-24 | 5-13 | 3-13 | 10-8 | ... | 8-21 | 6-11 | 7-27 | 1-0 | 3-12 | 3-13 | 6-8 | 7-5 | 2-12 | 2-8 |
28 | 29 | Charlotte Bobcats | 21-61 | 15-26 | 6-35 | 18-34 | 3-27 | 6-12 | 6-12 | 6-10 | ... | 9-21 | 6-6 | 6-37 | NaN | 7-8 | 1-15 | 3-11 | 2-10 | 4-12 | 4-5 |
29 | 30 | Orlando Magic | 20-62 | 12-29 | 8-33 | 10-42 | 10-20 | 2-16 | 5-13 | 3-13 | ... | 5-25 | 2-9 | 8-30 | NaN | 5-10 | 7-9 | 2-12 | 2-11 | 3-13 | 1-7 |
30 rows × 24 columns
# New feature: HomeTeamRanksHigher is 1 when the home team finished ABOVE the
# visitor on the previous season's ladder, else 0. A lower "Rk" is better
# (Rk 1 is the top team), so "ranks higher" means home_rank < visitor_rank.
# Franchises renamed since that season are mapped back to their ladder names.
team_renames = {"New Orleans Pelicans": "New Orleans Hornets"}
rank_by_team = ladder.set_index("Team")["Rk"]
home_rank = results["Home Team"].replace(team_renames).map(rank_by_team)
visitor_rank = results["Visitor Team"].replace(team_renames).map(rank_by_team)
# Fail loudly, naming the culprits, if a team is not on the ladder
# (map() would otherwise yield NaN and the comparison would silently be 0).
unknown_teams = sorted(
    set(results.loc[home_rank.isnull(), "Home Team"])
    | set(results.loc[visitor_rank.isnull(), "Visitor Team"])
)
if unknown_teams:
    raise KeyError("Teams missing from the ladder: {0}".format(unknown_teams))
results["HomeTeamRanksHigher"] = (home_rank < visitor_rank).astype(int)
results[:5]
Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | HomeWin | HomeLastWin | VisitorLastWin | HomeWinStreak | VisitorWinStreak | HomeTeamRanksHigher | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Tue Oct 29 2013 | Box Score | Los Angeles Clippers | 103 | Los Angeles Lakers | 116 | NaN | NaN | True | False | False | 0 | 0 | 1 |
1 | Tue Oct 29 2013 | Box Score | Chicago Bulls | 95 | Miami Heat | 107 | NaN | NaN | True | False | False | 0 | 0 | 0 |
2 | Wed Oct 30 2013 | Box Score | Brooklyn Nets | 94 | Cleveland Cavaliers | 98 | NaN | NaN | True | False | False | 0 | 0 | 1 |
3 | Wed Oct 30 2013 | Box Score | Atlanta Hawks | 109 | Dallas Mavericks | 118 | NaN | NaN | True | False | False | 0 | 0 | 1 |
4 | Wed Oct 30 2013 | Box Score | Washington Wizards | 102 | Detroit Pistons | 113 | NaN | NaN | True | False | False | 0 | 0 | 0 |
# Decision tree on last-game results plus the ladder comparison.
rank_feature_columns = ["HomeLastWin", "VisitorLastWin", "HomeTeamRanksHigher"]
X_homehigher = results[rank_feature_columns].values
clf = DecisionTreeClassifier(random_state=14)
scores = cross_val_score(estimator=clf, X=X_homehigher, y=y_true,
                         scoring='accuracy')
print("Using whether the home team is ranked higher")
print("准确率: {0:.1f}%".format(100 * np.mean(scores)))
Using whether the home team is ranked higher 准确率: 60.2%
# sklearn.grid_search was removed in scikit-learn 0.20; GridSearchCV now
# lives in sklearn.model_selection.
from sklearn.model_selection import GridSearchCV

# Search tree depths 1..20 inclusive.
parameter_space = {
    "max_depth": list(range(1, 21)),
}
clf = DecisionTreeClassifier(random_state=14)
grid = GridSearchCV(clf, parameter_space)
grid.fit(X_homehigher, y_true)
# best_score_: mean cross-validated score of the best parameter setting.
print("准确率: {0:.1f}%".format(grid.best_score_ * 100))
准确率: 60.5%
# Who won the last match between these two teams? Home/visitor roles are
# ignored: each pair is keyed in sorted order, so (A, B) and (B, A) share
# one history entry.
from collections import defaultdict

# sorted team pair -> name of the winner of their latest meeting
# (0 when the pair has not met yet, which never equals a team name).
last_match_winner = defaultdict(int)
home_team_won_last = []
# Walks rows top to bottom; assumes chronological order, as in the CSV.
for home_team, visitor_team, home_win in zip(results["Home Team"],
                                             results["Visitor Team"],
                                             results["HomeWin"]):
    teams = tuple(sorted([home_team, visitor_team]))  # consistent ordering
    # Feature uses the PREVIOUS meeting only, then record this game's winner.
    home_team_won_last.append(1 if last_match_winner[teams] == home_team else 0)
    last_match_winner[teams] = home_team if home_win else visitor_team
# DataFrame.ix was removed in pandas 1.0; assign the column in one go instead.
results["HomeTeamWonLast"] = home_team_won_last
# Label-based slice, inclusive of label 5 (same rows the old .ix[:5] showed).
results.loc[:5]
Date | Score Type | Visitor Team | VisitorPts | Home Team | HomePts | OT? | Notes | HomeWin | HomeLastWin | VisitorLastWin | HomeWinStreak | VisitorWinStreak | HomeTeamRanksHigher | HomeTeamWonLast | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Tue Oct 29 2013 | Box Score | Los Angeles Clippers | 103 | Los Angeles Lakers | 116 | NaN | NaN | True | False | False | 0 | 0 | 1 | 0 |
1 | Tue Oct 29 2013 | Box Score | Chicago Bulls | 95 | Miami Heat | 107 | NaN | NaN | True | False | False | 0 | 0 | 0 | 0 |
2 | Wed Oct 30 2013 | Box Score | Brooklyn Nets | 94 | Cleveland Cavaliers | 98 | NaN | NaN | True | False | False | 0 | 0 | 1 | 0 |
3 | Wed Oct 30 2013 | Box Score | Atlanta Hawks | 109 | Dallas Mavericks | 118 | NaN | NaN | True | False | False | 0 | 0 | 1 | 0 |
4 | Wed Oct 30 2013 | Box Score | Washington Wizards | 102 | Detroit Pistons | 113 | NaN | NaN | True | False | False | 0 | 0 | 0 | 0 |
5 | Wed Oct 30 2013 | Box Score | Los Angeles Lakers | 94 | Golden State Warriors | 125 | NaN | NaN | True | False | True | 0 | 1 | 0 | 0 |
# Decision tree on the ladder comparison plus the last head-to-head result.
X_home_higher = results[["HomeTeamRanksHigher", "HomeTeamWonLast"]].values
clf = DecisionTreeClassifier(random_state=14)
scores = cross_val_score(estimator=clf, X=X_home_higher, y=y_true,
                         scoring='accuracy')
print("Using whether the home team is ranked higher")
print("准确率: {0:.1f}%".format(100 * np.mean(scores)))
Using whether the home team is ranked higher 准确率: 60.5%
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# Map team names to integer ids. Fit on BOTH columns so a team that only ever
# appears as a visitor cannot make transform() fail.
encoding = LabelEncoder()
encoding.fit(np.concatenate([results["Home Team"].values,
                             results["Visitor Team"].values]))
home_teams = encoding.transform(results["Home Team"].values)
visitor_teams = encoding.transform(results["Visitor Team"].values)
# One row per game: [home team id, visitor team id].
X_teams = np.vstack([home_teams, visitor_teams]).T
# One-hot encode each column separately. toarray() returns a plain ndarray;
# todense() returned np.matrix, which scikit-learn >= 1.2 rejects.
onehot = OneHotEncoder()
X_teams = onehot.fit_transform(X_teams).toarray()
clf = DecisionTreeClassifier(random_state=14)
scores = cross_val_score(clf, X_teams, y_true, scoring='accuracy')
print("准确率: {0:.1f}%".format(np.mean(scores) * 100))
准确率: 61.2%
from sklearn.ensemble import RandomForestClassifier

# Random forest on the one-hot team encoding alone. np.asarray guards against
# an np.matrix input, which recent scikit-learn versions refuse.
clf = RandomForestClassifier(random_state=14)
scores = cross_val_score(clf, np.asarray(X_teams), y_true, scoring='accuracy')
print("Using full team labels")
print("准确率: {0:.1f}%".format(np.mean(scores) * 100))
Using full team labels is ranked higher 准确率: 60.5%
# Combine the ladder/head-to-head features with the one-hot team labels.
# np.asarray keeps X_all a plain ndarray even if X_teams is an np.matrix.
X_all = np.hstack([X_home_higher, np.asarray(X_teams)])
print(X_all.shape)
clf = RandomForestClassifier(random_state=14)
scores = cross_val_score(clf, X_all, y_true, scoring='accuracy')
print("Using ladder rank, last head-to-head result and full team labels")
print("准确率: {0:.1f}%".format(np.mean(scores) * 100))
Using whether the home team is ranked higher 准确率: 60.9%
# Grid search over random-forest hyper-parameters; anything not listed keeps
# scikit-learn's default.
parameter_space = {
    # 'sqrt' is what the old 'auto' meant for classifiers; 'auto' was
    # deprecated in scikit-learn 1.1 and removed in 1.3.
    "max_features": [2, 10, 'sqrt'],
    "n_estimators": [100,],
    "criterion": ["gini", "entropy"],
    "min_samples_leaf": [2, 4, 6],
}
clf = RandomForestClassifier(random_state=14)
grid = GridSearchCV(clf, parameter_space)
grid.fit(X_all, y_true)
# best_score_: mean cross-validated score of the best parameter combination.
print("准确率: {0:.1f}%".format(grid.best_score_ * 100))
print(grid.best_estimator_)
准确率: 63.8% RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini', max_depth=None, max_features='auto', max_leaf_nodes=None, min_impurity_split=1e-07, min_samples_leaf=6, min_samples_split=2, min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1, oob_score=False, random_state=14, verbose=0, warm_start=False)