In this demo we show how to predict whether a stock will go up or down based on news headlines. The solution consists of the following components: a Solr document store that provides the news articles, historical stock data from which the labels are derived, and a Spark ML pipeline that performs the classification.
We import the necessary libraries with the help of Maven
%classpath config resolver maven-public1 http://nuc.local:8081/repository/maven-public/
%%classpath add mvn
org.apache.spark:spark-sql_2.11:2.3.2
org.apache.spark:spark-mllib_2.11:2.3.2
ch.pschatzmann:news-digest:LATEST
ch.pschatzmann:investor:LATEST
com.github.habernal:confusion-matrix:1.0
And we import the necessary packages and classes:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._
import org.apache.spark.ml.feature.RegexTokenizer
import org.apache.spark.ml.feature.HashingTF
import org.apache.spark.ml.classification._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature._
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import ch.pschatzmann.news._
import ch.pschatzmann.stocks._
import ch.pschatzmann.stocks.ta4j.indicator._
import ch.pschatzmann.stocks.integration.HistoricValues
import ch.pschatzmann.stocks.accounting._
import org.ta4j.core.indicators._
import org.ta4j.core.indicators.helpers._
import com.github.habernal.confusionmatrix._
Finally we create a Spark session and context.
val spark = SparkSession.builder()
.appName("SolrQuery")
.master("local[*]")
.config("spark.ui.enabled", "false")
.getOrCreate()
val sc = spark.sparkContext
In order to build our features, we search for all news articles for the MSFT ticker and select the publication date and the article content:
import spark.implicits._
import scala.collection.JavaConverters._
val ticker = "MSFT"
val query = Utils.companyNameByTickerSearch(ticker)
val store = new SolrDocumentStore()
val rdd = sc.parallelize(store.pagedDocuments(query).asScala)
.map(page => page.values.asScala)
.flatMap(l => l)
val df = spark.createDataFrame(rdd, classOf[Document])
val newsDataset = df.withColumn("date", expr("substring(publishDate_t, 1, 10)"))
.select("date","content_t")
newsDataset.show
+----------+--------------------+
|      date|           content_t|
+----------+--------------------+
|2000-01-15|Microsoft boss ov...|
|2000-01-23|For your pleasure...|
|2000-02-01|Microsoft finance...|
|2000-02-02|365 Corporation p...|
|2000-02-17|Windows open on a...|
|2000-03-21|Microsoft to sell...|
|2000-03-24|Judgment day fill...|
|2000-03-25|Cisco eclipses Mi...|
|2000-04-02|Microsoft talks c...|
|2000-04-05|Markets swoon aft...|
|2000-04-12|Microsoft hires B...|
|2000-04-21|Microsoft victors...|
|2000-04-25|Microsoft plunges...|
|2000-04-25|Microsoft issue s...|
|2000-05-01|Microsoft bounces...|
|2000-05-11|World Online hire...|
|2000-05-25|Microsoft falls o...|
|2000-06-07|Judge orders Micr...|
|2000-06-28|Microsoft and AT&...|
|2000-07-06|Microsoft offers ...|
+----------+--------------------+
only showing top 20 rows
We plan to use the Investor SignIndicator, so we first create String labels for its numeric values:
val descriptions = Seq((0.0,"neutral"),(1.0 ,"positive"),(-1.0 ,"negative"))
val descriptionsDF = sc.parallelize(descriptions).toDF("value","labelDescr")
descriptionsDF.show
+-----+----------+
|value|labelDescr|
+-----+----------+
|  0.0|   neutral|
|  1.0|  positive|
| -1.0|  negative|
+-----+----------+
We determine the label values for all dates with the help of the SignIndicator, which is calculated from the sign of the difference between a future closing price and the prior closing price. If the value is 1.0 the stock will go up in the future; if it is -1.0 it will go down.
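The labeling rule can be sketched in plain Scala (illustrative values only; the actual calculation below uses the ta4j indicators):

```scala
// Label = sign of (closing price several days ahead minus the prior closing price).
def label(closePrior: Double, closeNext: Double): Double =
  math.signum(closeNext - closePrior)

label(100.0, 105.0) // 1.0: the price goes up
label(100.0, 95.0)  // -1.0: the price goes down
label(100.0, 100.0) // 0.0: neutral
```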
We convert the data into a dataframe and join it with the label descriptions:
var stockData = Context.getStockData(ticker)
var series = stockData.toTimeSeries()
var close = new ClosePriceIndicator(series)
var closePrior = new OffsetIndicator(close, -1)
var closeNext = new OffsetIndicator(close, +5)
var sentimentHistoricValues = new SignIndicator(new DifferenceIndicator(closeNext, closePrior)).toHistoricValues()
val dateFormat = new java.text.SimpleDateFormat("yyyy-MM-dd")
var sentimentDF = sc.parallelize(sentimentHistoricValues
.list().asScala
.map(r => (dateFormat.format(r.getDate), r.getValue)))
.toDF("date","value")
sentimentDF = sentimentDF.join(descriptionsDF,Seq("value"))
sentimentDF.printSchema
root
 |-- value: double (nullable = true)
 |-- date: string (nullable = true)
 |-- labelDescr: string (nullable = true)
sentimentDF.show
+-----+----------+----------+
|value|      date|labelDescr|
+-----+----------+----------+
|  0.0|1986-03-06|   neutral|
|  0.0|1986-03-14|   neutral|
|  0.0|1986-03-24|   neutral|
|  0.0|1986-04-18|   neutral|
|  0.0|1986-04-25|   neutral|
|  0.0|1986-05-05|   neutral|
|  0.0|1986-06-20|   neutral|
|  0.0|1986-07-15|   neutral|
|  0.0|1986-08-20|   neutral|
|  0.0|1987-01-28|   neutral|
|  0.0|1987-03-25|   neutral|
|  0.0|1987-07-07|   neutral|
|  0.0|1987-07-14|   neutral|
|  0.0|1987-07-15|   neutral|
|  0.0|1988-01-19|   neutral|
|  0.0|1988-02-17|   neutral|
|  0.0|1988-06-16|   neutral|
|  0.0|1988-09-01|   neutral|
|  0.0|1988-11-25|   neutral|
|  0.0|1989-01-09|   neutral|
+-----+----------+----------+
only showing top 20 rows
Our final input dataset is the newsDataset joined with the sentimentDF; it contains the following columns:
val resultDf = newsDataset.join(sentimentDF,Seq("date")).select("date","content_t","labelDescr")
resultDf.show
+----------+--------------------+----------+
|      date|           content_t|labelDescr|
+----------+--------------------+----------+
|2000-11-13|WHEN online broke...|  negative|
|2001-11-16|Rebounding from a...|  positive|
|2001-11-16|Oil Prices Drop S...|  positive|
|2001-11-16|The General Motor...|  positive|
|2001-11-16|The Telstra Corpo...|  positive|
|2001-11-16|On a day when oil...|  positive|
|2001-11-16|To the Editor: La...|  positive|
|2001-11-16|A federal appeals...|  positive|
|2001-11-16|The Mitsubishi Co...|  positive|
|2002-05-13|The Securities an...|  negative|
|2002-05-13|IF anyone has des...|  negative|
|2002-05-13|The last of the w...|  negative|
|2002-05-13|In the Fortune 50...|  negative|
|2002-05-13|WERTHEIMER-Franc....|  negative|
|2002-06-21|The decision by a...|  negative|
|2002-06-21|FALKINBURG-John N...|  negative|
|2002-06-21|Verizon Communica...|  negative|
|2002-06-21|The board of the ...|  negative|
|2002-06-21|WHITWELL-Joseph E...|  negative|
|2002-06-21|The information-g...|  negative|
+----------+--------------------+----------+
only showing top 20 rows
Now we can split the data into a training and test frame.
val Array(training, test) = resultDf.randomSplit(Array(0.9, 0.1), seed = 12345)
s"training: ${training.count} / test: ${test.count}"
training: 12353 / test: 1397
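Note that randomSplit samples rows probabilistically, so the realized ratio only approximates the requested 0.9/0.1. A quick check against the reported counts:

```scala
// Counts reported by the cell above.
val trainCount = 12353
val testCount  = 1397

// Realized training fraction: close to, but not exactly, the requested 0.9.
val ratio = trainCount.toDouble / (trainCount + testCount) // 0.8984
```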
We use the StringIndexer to convert the label into a number and the HashingTF to convert our headlines into a feature vector. The model is trained by calling fit, and we obtain predictions by calling transform on the model.
val indexer = new StringIndexer()
.setInputCol("labelDescr")
.setOutputCol("label")
.fit(resultDf)
val converter = new IndexToString()
.setInputCol("prediction")
.setOutputCol("predictedLabelDescr")
.setLabels(indexer.labels)
val tokenizer = new RegexTokenizer()
.setInputCol("content_t")
.setOutputCol("words")
val hashingTF = new HashingTF()
.setInputCol(tokenizer.getOutputCol)
.setOutputCol("features")
.setNumFeatures(5000)
val classifier = new LogisticRegression()
.setMaxIter(20)
.setRegParam(0.01)
val pipeline = new Pipeline().setStages(Array(tokenizer, indexer, hashingTF, classifier, converter))
val model = pipeline.fit(training)
val trainPredictions = model.transform(training)
val testPredictions = model.transform(test)
Here is the result of the predictions from our test dataset:
testPredictions.select("date","labelDescr","predictedLabelDescr").show
+----------+----------+-------------------+
|      date|labelDescr|predictedLabelDescr|
+----------+----------+-------------------+
|2001-11-16|  positive|           positive|
|2001-11-16|  positive|           positive|
|2009-12-04|  positive|           positive|
|2000-07-11|  negative|           negative|
|2000-10-25|  positive|           positive|
|2000-10-25|  positive|           negative|
|2000-10-25|  positive|           negative|
|2004-12-03|  negative|           negative|
|2005-02-01|  negative|           negative|
|2006-03-28|  negative|           positive|
|2007-02-02|  negative|           negative|
|2009-10-01|  positive|           negative|
|2001-05-04|  negative|           negative|
|2002-08-28|  negative|           positive|
|2002-12-31|  positive|           negative|
|2004-12-02|  negative|           negative|
|2006-03-30|  negative|           positive|
|2008-02-29|  positive|           negative|
|2008-09-25|  negative|           negative|
|2010-06-15|  negative|           positive|
+----------+----------+-------------------+
only showing top 20 rows
testPredictions.printSchema
root
 |-- date: string (nullable = true)
 |-- content_t: string (nullable = true)
 |-- labelDescr: string (nullable = true)
 |-- words: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)
 |-- predictedLabelDescr: string (nullable = true)
We can calculate the accuracy...
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy = evaluator.evaluate(testPredictions)
s"Accuracy = $accuracy"
Accuracy = 0.5204008589835362
... and the Confusion Matrix
val cm = new ConfusionMatrix()
testPredictions.collect.foreach(l => cm.increaseValue(l.getString(2),l.getString(9),1))
cm
↓gold\pred→  negative  neutral  positive
   negative       375        0       305
    neutral         1        0         3
   positive       361        0       352
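The matrix allows us to derive more than accuracy alone. A small sketch in plain Scala (counts copied from the output above; note that the model almost never predicts neutral, and its accuracy of ~0.52 is only slightly above the ~0.51 majority-class baseline):

```scala
// (gold, predicted) -> count, copied from the confusion matrix above.
val cm = Map(
  ("negative", "negative") -> 375, ("negative", "positive") -> 305,
  ("neutral",  "negative") -> 1,   ("neutral",  "positive") -> 3,
  ("positive", "negative") -> 361, ("positive", "positive") -> 352
)
val total    = cm.values.sum                                      // 1397
val correct  = cm.collect { case ((g, p), n) if g == p => n }.sum // 727
val accuracy = correct.toDouble / total                           // ≈ 0.5204

// Majority-class baseline: always predict the most frequent gold label.
val goldCounts = cm.groupBy(_._1._1).mapValues(_.values.sum)      // negative: 680, neutral: 4, positive: 713
val baseline   = goldCounts.values.max.toDouble / total           // 713 / 1397 ≈ 0.5104
```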
We now calculate the accuracy of the prediction for different numbers of days used in determining whether the stock goes up or down.
import scala.collection.JavaConverters._
import spark.implicits._
class StockPrediction(ticker:String, pipeline:Pipeline) {
// calculate the accuracy
def accuracy(days:Int):Double = {
val query = Utils.companyNameByTickerSearch(ticker)
val store = new SolrDocumentStore()
val rdd = sc.parallelize(store.pagedDocuments(query).asScala)
.map(page => page.values.asScala)
.flatMap(l => l)
val df = spark.createDataFrame(rdd, classOf[Document])
val newsDataset = df.withColumn("date", expr("substring(publishDate_t, 1, 10)"))
.select("date","content_t")
var stockData = Context.getStockData(ticker)
var series = stockData.toTimeSeries()
var close = new ClosePriceIndicator(series)
var closePrior = new OffsetIndicator(close, -1)
var closeNext = new OffsetIndicator(close, +days)
var sentimentHistoricValues = new SignIndicator(new DifferenceIndicator(closeNext, closePrior)).toHistoricValues()
val dateFormat = new java.text.SimpleDateFormat("yyyy-MM-dd")
var sentimentDF = sc.parallelize(sentimentHistoricValues
.list().asScala
.map(r => (dateFormat.format(r.getDate), r.getValue)))
.toDF("date","value")
sentimentDF = sentimentDF.join(descriptionsDF,Seq("value"))
val resultDf = newsDataset.join(sentimentDF,Seq("date")).select("date","content_t","labelDescr")
val Array(training, test) = resultDf.randomSplit(Array(0.9, 0.1), seed = 12345)
val model = pipeline.fit(training)
val testPredictions = model.transform(test)
val evaluator = new MulticlassClassificationEvaluator()
.setLabelCol("label")
.setPredictionCol("prediction")
.setMetricName("accuracy")
val accuracy = evaluator.evaluate(testPredictions)
return accuracy
}
}
new StockPrediction(ticker, pipeline).accuracy(5)
0.5204008589835362
We calculate the accuracy for 1 to 100 days and display the result as a chart:
val days = (1 to 100)
val prediction = new StockPrediction(ticker, pipeline)
val accuracies = days.map(n => prediction.accuracy(n))
[[0.5053686471009305, 0.5089477451682176, 0.5404438081603435, 0.521832498210451, 0.5204008589835362, 0.5397279885468862, 0.49964209019327127, 0.5189692197566214, 0.5497494631352899, 0.5318539727988547, 0.5311381531853973, 0.5404438081603435, 0.5161059413027917, 0.506084466714388, 0.513242662848962, 0.5304223335719399, 0.5518969219756621, 0.5375805297065139, 0.5289906943450251, 0.5561918396564066, 0.5583392984967788, 0.5590551181102362, 0.5547602004294918, 0.5497494631352899, 0.5390121689334287, 0.5497494631352899, 0.5390121689334287, 0.5340014316392269, 0.5325697924123121, 0.5068002863278454, 0.5211166785969935, 0.5246957766642806, 0.5425912670007158, 0.5447387258410881, 0.5340014316392269, 0.5497494631352899, 0.5576234788833214, 0.5590551181102362, 0.5626342161775233, 0.5447387258410881, 0.5576234788833214, 0.5318539727988547, 0.5390121689334287, 0.5476020042949177, 0.5461703650680029, 0.5476020042949177, 0.5612025769506085, 0.5669291338582677, 0.5511811023622047, 0.560486757337151, 0.5497494631352899, 0.560486757337151, 0.5483178239083751, 0.560486757337151, 0.5483178239083751, 0.5597709377236937, 0.56907659269864, 0.5612025769506085, 0.5447387258410881, 0.5597709377236937, 0.5404438081603435, 0.521832498210451, 0.5347172512526843, 0.5676449534717252, 0.5597709377236937, 0.5354330708661418, 0.5368647100930566, 0.5440229062276306, 0.5261274158911954, 0.5590551181102362, 0.5340014316392269, 0.5375805297065139, 0.5361488904795991, 0.5425912670007158, 0.5110952040085899, 0.5289906943450251, 0.5182534001431639, 0.541159627773801, 0.5504652827487473, 0.5418754473872585, 0.5239799570508232, 0.5368647100930566, 0.5361488904795991, 0.5433070866141733, 0.5347172512526843, 0.5175375805297066, 0.5547602004294918, 0.5447387258410881, 0.5361488904795991, 0.5297065139584825, 0.5425912670007158, 0.5418754473872585, 0.5382963493199714, 0.5368647100930566, 0.5397279885468862, 0.5476020042949177, 0.5518969219756621, 0.5425912670007158, 0.5397279885468862, 0.5533285612025769]]
val plot = new Plot
plot.add(new Line { x = days; y = accuracies })