Using Keras Models with the Python Scikit-Learn Library
Keras is a popular Python library for deep learning, but it focuses exclusively on deep learning: it is deliberately minimal, aimed at letting us define and build deep learning models quickly.
Scikit-learn is another Python library, built on top of SciPy for efficient numerical computation, and is widely used for general machine learning.
Keras provides convenient wrappers that let its deep learning models be used for classification and regression inside scikit-learn, such as KerasClassifier, a wrapper for neural network classifiers built in Keras.
# Evaluate a neural network for diabetes prediction with sklearn cross-validation
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import cross_val_score
import numpy as np
import urllib

# Model-building function required by KerasClassifier
def create_model():
    # Define the model
    model = Sequential()
    model.add(Dense(12, input_dim=8, init='uniform', activation='relu'))
    model.add(Dense(8, init='uniform', activation='relu'))
    model.add(Dense(1, init='uniform', activation='sigmoid'))
    # Compile the model
    model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
# Fix the random seed for reproducibility
seed = 42
np.random.seed(seed)

# Load the Pima Indians diabetes dataset from the UCI repository
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"
raw_data = urllib.urlopen(url)  # Python 2; on Python 3 use urllib.request.urlopen
# Parse the CSV file into a NumPy matrix
dataset = np.loadtxt(raw_data, delimiter=',')

# Split into features and labels
X = dataset[:, 0:8]
y = dataset[:, 8]

# Wrap the Keras model so sklearn can use it
model = KerasClassifier(build_fn=create_model, nb_epoch=150, batch_size=10, verbose=0)

# Evaluate with 10-fold stratified cross-validation
kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
results = cross_val_score(model, X, y, cv=kfold)
print("accuracy: {0} %".format(results.mean() * 100))
accuracy: 67.0454549068 %
Compared with the manual cross-validation in the previous post, Python Deep Learning in Practice 03 – Model Performance Evaluation, this approach is much more concise.
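For comparison, doing the evaluation by hand means constructing the stratified folds and the train/evaluate loop yourself. Below is a rough pure-Python sketch of the fold assignment that StratifiedKFold automates; `stratified_folds` is an illustrative helper, not part of any library, and it does no shuffling.

```python
from collections import defaultdict

def stratified_folds(labels, n_splits):
    """Assign sample indices to folds round-robin within each class,
    so every fold keeps roughly the same class proportions."""
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    folds = [[] for _ in range(n_splits)]
    for indices in by_class.values():
        for pos, idx in enumerate(indices):
            folds[pos % n_splits].append(idx)
    return folds

# 100 samples, 50 per class, split into 10 folds
labels = [0] * 50 + [1] * 50
folds = stratified_folds(labels, 10)
# Each of the 10 folds gets 5 samples of each class
print([len(f) for f in folds])  # [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]
```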
From the above we can see that Keras ships with wrappers such as KerasClassifier for working with the scikit-learn library, which is very convenient. We can also use scikit-learn's grid search to find better neural network hyperparameters, although only a limited set of them: we do not search over the network topology here, i.e. the number of layers and the number of neurons in each hidden layer. Settling on a topology is itself a hard problem.
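That said, topology can in principle be searched the same way: any keyword argument of the build function can be listed in the parameter grid, because the wrapper forwards matching grid keys to build_fn. Here is a pure-Python stand-in illustrating that forwarding mechanism; `create_model` below is a hypothetical plain function that only records its settings, not a real Keras model.

```python
def create_model(neurons=12, optimizer='rmsprop'):
    # Stand-in for a Keras build function: record the settings it receives
    return {'neurons': neurons, 'optimizer': optimizer}

# Exposing the hidden-layer size as a grid key makes it tunable
param_grid = {'neurons': [4, 8, 12], 'optimizer': ['rmsprop', 'adam']}

# The wrapper forwards matching grid keys to build_fn, roughly like this:
models = [create_model(neurons=n, optimizer=o)
          for n in param_grid['neurons']
          for o in param_grid['optimizer']]
print(len(models))  # 3 * 2 = 6 candidate configurations
```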
The code below shows how to tune the deep learning model with scikit-learn's grid search. Using GridSearchCV, the tunable parameters of the network are:
optimizers = ['rmsprop', 'adam']
init = ['glorot_uniform', 'normal', 'uniform']
epochs = np.array([50, 100, 150])
batches = np.array([5, 10, 20])
This gives 2 × 3 × 3 × 3 = 54 networks with different parameter settings. If the dataset is very large, the search will take far longer than on a small sample.
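The size of such a grid is just the Cartesian product of the value lists, which can be sanity-checked with the standard library (the values below match the grid used in the code):

```python
from itertools import product

# Hyperparameter values to be searched
optimizers = ['rmsprop', 'adam']
inits = ['glorot_uniform', 'normal', 'uniform']
epochs = [50, 100, 150]
batches = [5, 10, 20]

# One grid-search candidate per element of the Cartesian product
candidates = list(product(optimizers, inits, epochs, batches))
print(len(candidates))  # 2 * 3 * 3 * 3 = 54
```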
# Tune the diabetes-prediction network with sklearn grid search
from keras.models import Sequential
from keras.layers import Dense
from keras.wrappers.scikit_learn import KerasClassifier
from sklearn.model_selection import GridSearchCV
import numpy as np
import urllib

# Model-building function required by KerasClassifier;
# defaults: rmsprop optimizer, glorot_uniform weight initialization
def create_model(optimizer='rmsprop', init='glorot_uniform'):
    # Define the model
    model = Sequential()
    model.add(Dense(12, input_dim=8, init=init, activation='relu'))
    model.add(Dense(8, init=init, activation='relu'))
    model.add(Dense(1, init=init, activation='sigmoid'))
    # Compile the model
    model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['accuracy'])
    return model
# Fix the random seed for reproducibility
seed = 7
np.random.seed(seed)

# Load the Pima Indians diabetes dataset from the UCI repository
url = "http://archive.ics.uci.edu/ml/machine-learning-databases/pima-indians-diabetes/pima-indians-diabetes.data"
raw_data = urllib.urlopen(url)  # Python 2; on Python 3 use urllib.request.urlopen
# Parse the CSV file into a NumPy matrix
dataset = np.loadtxt(raw_data, delimiter=',')

# Split into features and labels
X = dataset[:, 0:8]
Y = dataset[:, 8]

# Wrap the Keras model so sklearn can use it
model = KerasClassifier(build_fn=create_model, verbose=0)

# Parameter grid: optimizer, weight initialization, epochs, batch size
optimizers = ['rmsprop', 'adam']
init = ['glorot_uniform', 'normal', 'uniform']
epochs = np.array([50, 100, 150])
batches = np.array([5, 10, 20])
param_grid = dict(optimizer=optimizers, nb_epoch=epochs, batch_size=batches, init=init)
grid = GridSearchCV(estimator=model, param_grid=param_grid)
grid_result = grid.fit(X, Y)

# Report the results
print("Best: %f using %s" % (grid_result.best_score_, grid_result.best_params_))
for params, mean_score, scores in grid_result.grid_scores_:
    print("%f (%f) with: %r" % (scores.mean(), scores.std(), params))
Best: 0.697917 using {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 5}
0.636719 (0.024910) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 5}
0.665365 (0.021236) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 5}
0.643229 (0.030647) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 5}
0.670573 (0.020752) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 5}
0.644531 (0.020915) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 5}
0.651042 (0.027498) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 5}
0.678385 (0.027126) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 5}
0.675781 (0.020915) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 5}
0.682292 (0.004872) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 5}
0.688802 (0.013279) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 5}
0.670573 (0.015733) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 5}
0.697917 (0.007366) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 5}
0.661458 (0.031466) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 5}
0.694010 (0.025780) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 5}
0.628906 (0.048265) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 5}
0.669271 (0.009207) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 5}
0.677083 (0.020505) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 5}
0.684896 (0.019225) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 5}
0.526042 (0.110685) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 10}
0.583333 (0.086547) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 10}
0.467448 (0.162891) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 10}
0.529948 (0.145472) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 10}
0.552083 (0.123869) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 10}
0.640625 (0.022326) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 10}
0.654948 (0.027126) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 10}
0.680990 (0.014382) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 10}
0.667969 (0.034499) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 10}
0.670573 (0.012890) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 10}
0.675781 (0.011049) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 10}
0.666667 (0.009207) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 10}
0.673177 (0.023073) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 10}
0.678385 (0.003683) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 10}
0.680990 (0.008027) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 10}
0.674479 (0.030145) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 10}
0.664063 (0.009568) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 10}
0.671875 (0.013902) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 10}
0.671875 (0.011500) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 20}
0.635417 (0.043537) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 20}
0.636719 (0.022999) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 20}
0.648438 (0.011500) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 20}
0.615885 (0.035132) with: {'init': 'glorot_uniform', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 20}
0.651042 (0.027866) with: {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 20}
0.662760 (0.024774) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 20}
0.651042 (0.027866) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 20}
0.671875 (0.000000) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 20}
0.669271 (0.012890) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 20}
0.671875 (0.022326) with: {'init': 'normal', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 20}
0.647135 (0.013279) with: {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 20}
0.651042 (0.025780) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 50, 'batch_size': 20}
0.654948 (0.019488) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 50, 'batch_size': 20}
0.656250 (0.020915) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 100, 'batch_size': 20}
0.645833 (0.027126) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 100, 'batch_size': 20}
0.656250 (0.027621) with: {'init': 'uniform', 'optimizer': 'rmsprop', 'nb_epoch': 150, 'batch_size': 20}
0.656250 (0.028348) with: {'init': 'uniform', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 20}
D:\ProgramData\Anaconda2\lib\site-packages\sklearn\model_selection\_search.py:667: DeprecationWarning: The grid_scores_ attribute was deprecated in version 0.18 in favor of the more elaborate cv_results_ attribute. The grid_scores_ attribute will not be available from 0.20 DeprecationWarning)
Best: 0.697917 using {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 5}
The best classification accuracy is about 69.79%, achieved with the parameter settings shown above.
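One last note: the run above also emits a DeprecationWarning, because `grid_scores_` was deprecated in scikit-learn 0.18 and removed in 0.20. Newer code should read `cv_results_` instead, a dict of parallel arrays. Here is a sketch of the replacement reporting loop; the dict literal stands in for a real `grid_result.cv_results_`, with two rows echoing results from the run above.

```python
# Shape of GridSearchCV.cv_results_: parallel arrays keyed by name.
# The two rows echo results from the grid search above.
cv_results = {
    'mean_test_score': [0.651042, 0.697917],
    'std_test_score': [0.027498, 0.007366],
    'params': [
        {'init': 'glorot_uniform', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 5},
        {'init': 'normal', 'optimizer': 'adam', 'nb_epoch': 150, 'batch_size': 5},
    ],
}

# Modern replacement for the deprecated grid_scores_ loop
for mean, std, params in zip(cv_results['mean_test_score'],
                             cv_results['std_test_score'],
                             cv_results['params']):
    print("%f (%f) with: %r" % (mean, std, params))
```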