# spark
from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext, SparkSession
# pipeline
from pyspark.ml import Pipeline
# model
from pyspark.ml.classification import (RandomForestClassifier,
GBTClassifier,
DecisionTreeClassifier)
# config
conf = SparkConf().setAppName("building a TREE MODEL")
# getOrCreate() reuses an already-running SparkContext (e.g. in the pyspark
# shell or a notebook) instead of failing with "Cannot run multiple
# SparkContexts at once"; note that `conf` is not applied to a reused context.
sc = SparkContext.getOrCreate(conf=conf)
# NOTE(review): SQLContext is deprecated since Spark 3.0 and is not used in this
# section; it is kept only in case later cells still reference `sqlCtx`.
sqlCtx = SQLContext(sc)
# SparkSession.builder.getOrCreate() reuses the SparkContext obtained above.
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
# Load the sample dataset with the LIBSVM reader (yields `label` and
# `features` columns, as shown by data.show()).
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')
data.show()
# Recorded output of data.show():
# +-----+--------------------+
# |label|            features|
# +-----+--------------------+
# |  0.0|(692,[127,128,129...|
# |  1.0|(692,[158,159,160...|
# |  1.0|(692,[124,125,126...|
# |  1.0|(692,[152,153,154...|
# |  1.0|(692,[151,152,153...|
# |  0.0|(692,[129,130,131...|
# |  1.0|(692,[158,159,160...|
# |  1.0|(692,[99,100,101,...|
# |  0.0|(692,[154,155,156...|
# |  0.0|(692,[127,128,129...|
# |  1.0|(692,[154,155,156...|
# |  0.0|(692,[153,154,155...|
# |  0.0|(692,[151,152,153...|
# |  1.0|(692,[129,130,131...|
# |  0.0|(692,[154,155,156...|
# |  1.0|(692,[150,151,152...|
# |  0.0|(692,[124,125,126...|
# |  0.0|(692,[152,153,154...|
# |  1.0|(692,[97,98,99,12...|
# |  1.0|(692,[124,125,126...|
# +-----+--------------------+
# only showing top 20 rows
# Split 70/30 into train/test. A fixed seed makes the split, and therefore the
# accuracies reported below, reproducible across runs (randomSplit's default
# seed is None, i.e. a different split every run).
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)
# build the models (default hyper-parameters)
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier()
gbt = GBTClassifier()
# train
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)
# predict: transform() appends rawPrediction/probability/prediction columns
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)
# show the decision tree predictions
dtc_preds.show()
# Recorded output of dtc_preds.show():
# +-----+--------------------+-------------+-----------+----------+
# |label|            features|rawPrediction|probability|prediction|
# +-----+--------------------+-------------+-----------+----------+
# |  0.0|(692,[95,96,97,12...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[100,101,102...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[121,122,123...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[123,124,125...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[123,124,125...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[124,125,126...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[124,125,126...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[125,126,127...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[126,127,128...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[126,127,128...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[127,128,129...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[127,128,129...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[153,154,155...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
# |  0.0|(692,[154,155,156...|   [0.0,37.0]|  [0.0,1.0]|       1.0|
# |  0.0|(692,[234,235,237...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
# |  1.0|(692,[100,101,102...|   [0.0,37.0]|  [0.0,1.0]|       1.0|
# |  1.0|(692,[123,124,125...|   [0.0,37.0]|  [0.0,1.0]|       1.0|
# |  1.0|(692,[123,124,125...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
# |  1.0|(692,[123,124,125...|   [0.0,37.0]|  [0.0,1.0]|       1.0|
# |  1.0|(692,[125,126,153...|   [0.0,37.0]|  [0.0,1.0]|       1.0|
# +-----+--------------------+-------------+-----------+----------+
# only showing top 20 rows
# MulticlassClassificationEvaluator works on the binary class dataset as well
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Scores the default 'label' and 'prediction' columns of a predictions frame.
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')
print('DTC (DECISION TREE) ACCURACY : ')
# Print explicitly: a bare expression is only echoed in an interactive
# session, so the accuracy would be silently discarded when run as a script.
print(acc_eval.evaluate(dtc_preds))
# Recorded output:
# DTC (DECISION TREE) ACCURACY :
# 0.9117647058823529
print('RFC (RANDOM FOREST) ACCURACY : ')
# Print explicitly so the value is not lost outside an interactive session.
print(acc_eval.evaluate(rfc_preds))
# The gradient-boosted model is fit and scored above but was never evaluated;
# report it too so all three tree models are compared on the same test split.
print('GBT (GRADIENT BOOSTED TREES) ACCURACY : ')
print(acc_eval.evaluate(gbt_preds))
# Recorded output:
# RFC (RANDOM FOREST) ACCURACY :
# 1.0
# FEATURE IMPORTANCE
# Random forest per-feature importances, a SparseVector mapping feature
# index -> importance. Printed explicitly so it also appears when this file
# runs as a script (a bare expression is only echoed interactively).
print(rfc_model.featureImportances)
# Recorded output:
# SparseVector(692, {100: 0.0036, 185: 0.0029, 272: 0.0416, 292: 0.0026,
#   295: 0.0026, 300: 0.0107, 317: 0.0393, 322: 0.0027, 325: 0.005,
#   343: 0.003, 350: 0.045, 351: 0.0443, 355: 0.0023, 359: 0.0026,
#   374: 0.0696, 377: 0.0792, 379: 0.0471, 401: 0.0299, 403: 0.0027,
#   406: 0.0477, 411: 0.0039, 415: 0.0084, 426: 0.0057, 428: 0.0447,
#   434: 0.0618, 455: 0.0475, 456: 0.0107, 457: 0.0113, 462: 0.0471,
#   463: 0.0467, 464: 0.0034, 490: 0.0464, 491: 0.0061, 510: 0.0162,
#   511: 0.0393, 512: 0.0033, 517: 0.0452, 526: 0.0121, 540: 0.05,
#   598: 0.0001, 637: 0.0027, 661: 0.0031})
# end of 13.46
# next : 14