In this document we show how to use Spark's MLib Machine Learning functionality using the RandomForest classifier on the IRIS data. Most of the examples that I have found are loading the data from a file. In my example I load it from the Internet, so the only thing you need to execute the example is a working network connection.
%%classpath add mvn
org.apache.spark:spark-sql_2.11:2.3.2
org.apache.spark:spark-mllib_2.11:2.3.2
import org.apache.spark.sql.SparkSession
val spark = SparkSession.builder()
.appName("Iris NaiveBayes")
.master("local")
.config("spark.ui.enabled", "false")
.getOrCreate()
org.apache.spark.sql.SparkSession@36de8728
We load the data form the Internet with the help of a url.
import java.net.URL
import scala.io.Source
import spark.implicits._
var url = "https://gist.githubusercontent.com/netj/8836201/raw/6f9306ad21398ea43cba4f7d537619d0e07d5ae3/iris.csv"
val streamString = Source.fromURL(new URL(url)).mkString
val csvList = streamString.lines.toList
val in = spark.read
.option("header", "true")
.option("inferSchema", "true")
.csv(csvList.toDS())
in.printSchema()
root |-- sepal.length: double (nullable = true) |-- sepal.width: double (nullable = true) |-- petal.length: double (nullable = true) |-- petal.width: double (nullable = true) |-- variety: string (nullable = true)
org.apache.spark.sql.SparkSession$implicits$@723693c9
We rename the fiels because the dot is creating issues. We could escape the field names with ´ but to work with proper field names is preferrable:
var data = in
.withColumn("sepalLength", in.col("`sepal.length`").cast("double"))
.withColumn("sepalWidth", in.col("`sepal.width`").cast("double"))
.withColumn("petalLength", in.col("`petal.length`").cast("double"))
.withColumn("petalWidth", in.col("`petal.width`").cast("double"))
.withColumn("label", in.col("variety"))
.drop("sepal.length", "sepal.width","petal.length","petal.width","variety")
data.printSchema()
data.show()
root |-- sepalLength: double (nullable = true) |-- sepalWidth: double (nullable = true) |-- petalLength: double (nullable = true) |-- petalWidth: double (nullable = true) |-- label: string (nullable = true) +-----------+----------+-----------+----------+------+ |sepalLength|sepalWidth|petalLength|petalWidth| label| +-----------+----------+-----------+----------+------+ | 5.1| 3.5| 1.4| 0.2|Setosa| | 4.9| 3.0| 1.4| 0.2|Setosa| | 4.7| 3.2| 1.3| 0.2|Setosa| | 4.6| 3.1| 1.5| 0.2|Setosa| | 5.0| 3.6| 1.4| 0.2|Setosa| | 5.4| 3.9| 1.7| 0.4|Setosa| | 4.6| 3.4| 1.4| 0.3|Setosa| | 5.0| 3.4| 1.5| 0.2|Setosa| | 4.4| 2.9| 1.4| 0.2|Setosa| | 4.9| 3.1| 1.5| 0.1|Setosa| | 5.4| 3.7| 1.5| 0.2|Setosa| | 4.8| 3.4| 1.6| 0.2|Setosa| | 4.8| 3.0| 1.4| 0.1|Setosa| | 4.3| 3.0| 1.1| 0.1|Setosa| | 5.8| 4.0| 1.2| 0.2|Setosa| | 5.7| 4.4| 1.5| 0.4|Setosa| | 5.4| 3.9| 1.3| 0.4|Setosa| | 5.1| 3.5| 1.4| 0.3|Setosa| | 5.7| 3.8| 1.7| 0.3|Setosa| | 5.1| 3.8| 1.5| 0.3|Setosa| +-----------+----------+-----------+----------+------+ only showing top 20 rows
null
Finally we split the data into a training and test dataset and this concludes our data preparation.
// Split the data into training and test sets (20% held out for testing)
val Array(trainingData, testData) = data.randomSplit(Array(0.8, 0.2), seed = 1234L)
[sepalLength: double, sepalWidth: double ... 3 more fields]
The features need to be vectorized. We can do this with the VectorAssembler. The labels are still Strings and we need to convert them to a numeric value. We do this with the StringIndexer. We classify the data with a RandomForestClassifier and convert the predicted data back to a string with IndexToString.
All these steps are collected in a Pipline which we use to fit and to predict (transform):
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.VectorAssembler
// Build the Feature Vector
val vectorAssembler = new VectorAssembler()
.setInputCols(Array("sepalLength", "sepalWidth", "petalLength","petalWidth"))
.setOutputCol("features")
// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("indexedLabel")
.fit(data)
// Train a RandomForest model.
val classifier = new NaiveBayes()
.setLabelCol("indexedLabel")
.setFeaturesCol("features")
// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabel")
.setLabels(labelIndexer.labels)
// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline()
.setStages(Array(vectorAssembler, labelIndexer, classifier, labelConverter))
// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)
// Make predictions.
val predictions = model.transform(testData)
//predictions.show
predictions.select("predictedLabel", "label", "features").show(10)
+--------------+----------+-----------------+ |predictedLabel| label| features| +--------------+----------+-----------------+ | Setosa| Setosa|[4.3,3.0,1.1,0.1]| | Setosa| Setosa|[4.4,2.9,1.4,0.2]| | Setosa| Setosa|[4.4,3.0,1.3,0.2]| | Setosa| Setosa|[4.8,3.1,1.6,0.2]| | Setosa| Setosa|[5.0,3.3,1.4,0.2]| | Setosa| Setosa|[5.0,3.4,1.5,0.2]| | Setosa| Setosa|[5.0,3.6,1.4,0.2]| | Setosa| Setosa|[5.1,3.4,1.5,0.2]| | Versicolor|Versicolor|[5.2,2.7,3.9,1.4]| | Setosa| Setosa|[5.2,4.1,1.5,0.1]| +--------------+----------+-----------------+ only showing top 10 rows
null
We need to do the evaluation of the accuracy of our model on the numerical labels. Here is the current Schema:
predictions.printSchema
root |-- sepalLength: double (nullable = true) |-- sepalWidth: double (nullable = true) |-- petalLength: double (nullable = true) |-- petalWidth: double (nullable = true) |-- label: string (nullable = true) |-- features: vector (nullable = true) |-- indexedLabel: double (nullable = false) |-- rawPrediction: vector (nullable = true) |-- probability: vector (nullable = true) |-- prediction: double (nullable = false) |-- predictedLabel: string (nullable = true)
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// Select (prediction, true label) and compute test error
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("indexedLabel")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
0.9583333333333334