DL4J - Stock Forecasting with LSTM

In this demo we show how to forecast the closing price of AAPL using a LSTM RRN network. We use the open, closing, high and low rates and the volume of the current day as input in order to predict the subsequent closing price.

This demo has been implemented in Scala using Jupyter with the BeakerX kernel using the following libraries

  • Investor
  • DL4j

Setup

We add the necessary java libraries with the help of Maven...

In [17]:
%classpath config resolver maven-public1 http://nuc.local:8081/repository/maven-public/
%%classpath add mvn 
ch.pschatzmann:investor:LATEST
ch.pschatzmann:investor-dl4j:LATEST
org.nd4j:nd4j-native:1.0.0-beta2
org.deeplearning4j:deeplearning4j-core:1.0.0-beta2

... and we import all relevant packages

In [18]:
import org.ta4j.core.Indicator
import org.ta4j.core.num.Num
import ch.pschatzmann.stocks.forecasting._
import ch.pschatzmann.stocks.Context
import ch.pschatzmann.stocks.ta4j.indicator._
import ch.pschatzmann.stocks.integration.dl4j._
import ch.pschatzmann.stocks.integration.StockTimeSeries
import ch.pschatzmann.stocks.integration.HistoricValues
import ch.pschatzmann.stocks.data.index.SP500Index
import ch.pschatzmann.display.Table

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.eval._;
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers._;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.datasets.datavec._;
import org.deeplearning4j.evaluation._
import org.deeplearning4j.nn.api.OptimizationAlgorithm
import org.datavec.api.records.reader.RecordReader
import org.nd4j.linalg.activations.Activation
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator
import org.nd4j.linalg.learning.config._
import org.nd4j.linalg.lossfunctions.LossFunctions
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.nd4j.linalg.dataset.api.preprocessor._
import org.nd4j.linalg.dataset.api._
import org.ta4j.core.indicators.helpers._
import org.ta4j.core.Indicator
import org.ta4j.core.num.Num

import scala.collection.mutable.ListBuffer
import scala.collection.Map
Out[18]:
import org.ta4j.core.Indicator
import org.ta4j.core.num.Num
import ch.pschatzmann.stocks.forecasting._
import ch.pschatzmann.stocks.Context
import ch.pschatzmann.stocks.ta4j.indicator._
import ch.pschatzmann.stocks.integration.dl4j._
import ch.pschatzmann.stocks.integration.StockTimeSeries
import ch.pschatzmann.stocks.integration.HistoricValues
import ch.pschatzmann.stocks.data.index.SP500Index
import ch.pschatzmann.display.Table
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator
import org.deeplearning4j.eval._
import org.deeplearning4j.nn.conf._
import org.deeplearning4j.nn.conf.inputs.InputType
import org.deeplearning4j.nn.conf.layers._
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork
import org.deeplearning4j.nn.weights.WeightInit
import org.deeplearning4...

Data Generation

We load APPL as our input data with the help of the Investor framework and define the input and output as list of Indicators. The input data is determined by shifting the data by -1 day with the help of the OffsetIndicator class.

In [19]:
var timeSeries = Context.getStockData("AAPL").toTimeSeries()

// get actual indictors
var open = new OpenPriceIndicator(timeSeries)
var close = new ClosePriceIndicator(timeSeries)
var high = new MaxPriceIndicator(timeSeries)
var low = new MinPriceIndicator(timeSeries)
var volume = new VolumeIndicator(timeSeries)

// create forecast
var openHistory =  new OffsetIndicator(open, -1)
var closeHistory =  new OffsetIndicator(close, -1)
var highHistory =  new OffsetIndicator(high, -1)
var lowHistory =  new OffsetIndicator(low, -1)
var volumeHistory = new OffsetIndicator(volume, -1)

var in:List[org.ta4j.core.Indicator[org.ta4j.core.num.Num]] = List(closeHistory, openHistory, highHistory , lowHistory , volumeHistory) 
var out:List[org.ta4j.core.Indicator[org.ta4j.core.num.Num]] = List(close)

var table = Table.create(close,closeHistory,openHistory,highHistory,lowHistory,volumeHistory)
In [20]:
new SimpleTimePlot {
    data = table.seq()
    columns = Seq("ClosePriceIndicator")
}

Test Data

We split the data and generate a StockData3DIterator for the training data. We use

  • batches of 50 entries.
  • 10 periods as input
  • 10 periods as output (with no masking)
  • a sliding window of 10 periods (so that each batch contains no overlapping data)

The scaling for stock data can be quite tricky: We will scale the data automatically for each iteration separatly. This can be done easily by calling setScalingPerDataset(true) on the StockData3DIterator.

In [21]:
import scala.collection.JavaConverters._

var splitDate = Context.date("2018-01-01")
var inputTrain = IndicatorSplitter.split(in.toList.asJava.asInstanceOf[java.util.List[Indicator[Num]]], splitDate, true)
var outputTrain = IndicatorSplitter.split(out.toList.asJava.asInstanceOf[java.util.List[Indicator[Num]]], splitDate, true)
var iteratorTrain = new StockData3DIterator(inputTrain, outputTrain, 50, 10, 10, 10); 
iteratorTrain.setScalingPerDataset(true)
Out[21]:
null

Here is the output of the first dataset:

In [22]:
iteratorTrain.next()
Out[22]:
===========INPUT===================
[[[    0.7100,    0.6500,    0.5700  ...    0.7950    0.8600,    0.9800], 
  [    0.7085,    0.6533,    0.5729  ...    0.7940    0.8593,    0.9799], 
  [    0.7100,    0.6500,    0.5700  ...    0.7950    0.8600,    0.9800], 
  [    0.7100,    0.6500,    0.5700  ...    0.7950    0.8600,    0.9800], 
  [    1.0000,    0.3694,    0.2185  ...    0.0920    0.0943,    0.1106]], 

 [[    1.0000,    0.9650,    0.9250  ...    0.7700    0.8350,    0.8250], 
  [    1.0000,    0.9698,    0.9296  ...    0.7739    0.8342,    0.8342], 
  [    1.0000,    0.9650,    0.9250  ...    0.7700    0.8350,    0.8300], 
  [    1.0000,    0.9650,    0.9250  ...    0.7700    0.8350,    0.8250], 
  [    0.1914,    0.1392,    0.0679  ...    0.0767    0.0373,    0.0420]], 

 [[    0.7800,    0.7850,    0.8100  ...    0.8750    0.8700,    0.8500], 
  [    0.7839,    0.7839,    0.8090  ...    0.8744    0.8744,    0.8543], 
  [    0.7800,    0.7850,    0.8150  ...    0.8800    0.8750,    0.8500], 
  [    0.7800,    0.7850,    0.8100  ...    0.8750    0.8700,    0.8500], 
  [    0.0406,    0.0218,    0.0213  ...    0.0675    0.0152,    0.0440]], 

  ..., 

 [[    0.5350,    0.5400,    0.5650  ...    0.7900    0.8000,    0.7650], 
  [    0.5930,    0.5327,    0.5377  ...    0.7035    0.7889,    0.7889], 
  [    0.5950,    0.5400,    0.5650  ...    0.7850    0.8300,    0.7850], 
  [    0.5300,    0.4900,    0.5400  ...    0.7050    0.7800,    0.7450], 
  [    0.3889,    0.3519,    0.4023  ...    0.4968    0.6989,    0.2954]], 

 [[    0.7150,    0.7550,    0.8000  ...    0.8150    0.8150,    0.7950], 
  [    0.7638,    0.7136,    0.7588  ...    0.7588    0.8141,    0.8141], 
  [    0.7700,    0.7600,    0.8150  ...    0.8150    0.8300,    0.8200], 
  [    0.7100,    0.7100,    0.7600  ...    0.7600    0.8100,    0.7900], 
  [    0.2474,    0.3778,    0.4273  ...    0.3011    0.3195,    0.2004]], 

 [[    0.6850,    0.7150,    0.7400  ...    0.8600    0.8300,    0.9000], 
  [    0.7940,    0.6985,    0.7136  ...    0.8593    0.8442,    0.8291], 
  [    0.7900,    0.7450,    0.7750  ...    0.8750    0.8400,    0.9050], 
  [    0.6850,    0.7000,    0.7100  ...    0.8400    0.8150,    0.8200], 
  [    0.2088,    0.1814,    0.1497  ...    0.3454    0.0934,    0.3064]]]
=================OUTPUT==================
[[[    0.6500,    0.5700,    0.5950,    0.6250,    0.6900,    0.7450,    0.7950,    0.8600,    0.9800,    1.0000]], 

 [[    0.9650,    0.9250,    0.9400,    0.9100,    0.8500,    0.7950,    0.7700,    0.8350,    0.8250,    0.7800]], 

 [[    0.7850,    0.8100,    0.8000,    0.8750,    0.8350,    0.8600,    0.8750,    0.8700,    0.8500,    0.8400]], 

 [[    0.8000,    0.7550,    0.6900,    0.6250,    0.6650,    0.7050,    0.7050,    0.7100,    0.6500,    0.6500]], 

 [[    0.6150,    0.6050,    0.5800,    0.6050,    0.6500,    0.5850,    0.5300,    0.5450,    0.5100,    0.5700]], 

 [[    0.5850,    0.6200,    0.6250,    0.6100,    0.6000,    0.5950,    0.5850,    0.5050,    0.4600,    0.4250]], 

 [[    0.4600,    0.4500,    0.4850,    0.5300,    0.5900,    0.5800,    0.5900,    0.6300,    0.6250,    0.6050]], 

 [[    0.5850,    0.5500,    0.5500,    0.5400,    0.5300,    0.6150,    0.6200,    0.6000,    0.5900,    0.6400]], 

 [[    0.6600,    0.6750,    0.6750,    0.6750,    0.6200,    0.5600,    0.5900,    0.6600,    0.7000,    0.7300]], 

 [[    0.7200,    0.7100,    0.6900,    0.6750,    0.6950,    0.6950,    0.6900,    0.6850,    0.6550,    0.6700]], 

 [[    0.6800,    0.6550,    0.6550,    0.6500,    0.6350,    0.6600,    0.6800,    0.6600,    0.6950,    0.7600]], 

 [[    0.8150,    0.8100,    0.8800,    0.8800,    0.8850,    0.8850,    0.8200,    0.8200,    0.8450,    0.8250]], 

 [[    0.7800,    0.8050,    0.8200,    0.8750,    0.8600,    0.8550,    0.8300,    0.8100,    0.8050,    0.7700]], 

 [[    0.7250,    0.7450,    0.7150,    0.7400,    0.7250,    0.6850,    0.6000,    0.5900,    0.5900,    0.5550]], 

 [[    0.5650,    0.6050,    0.5250,    0.4500,    0.4700,    0.5100,    0.5350,    0.5600,    0.5950,    0.5250]], 

 [[    0.5200,    0.4650,    0.4900,    0.5200,    0.5600,    0.5250,    0.5100,    0.5450,    0.5600,    0.5500]], 

 [[    0.5650,    0.5950,    0.5700,    0.5700,    0.5400,    0.5200,    0.4900,    0.4750,    0.4450,    0.4250]], 

 [[    0.4150,    0.4250,    0.3650,    0.3150,    0.3350,    0.3200,    0.3250,    0.3650,    0.3650,    0.4150]], 

 [[    0.4300,    0.3850,    0.3750,    0.3500,    0.3500,    0.3550,    0.3450,    0.3200,    0.3000,    0.2850]], 

 [[    0.2650,    0.2700,    0.2750,    0.2350,    0.2200,    0.2150,    0.1300,    0.1350,    0.1650,    0.1700]], 

 [[    0.1700,    0.2200,    0.2400,    0.2350,    0.2750,    0.3000,    0.3050,    0.3300,    0.3300,    0.2850]], 

 [[    0.3000,    0.2900,    0.3050,    0.3450,    0.3450,    0.3400,    0.3200,    0.3200,    0.3350,    0.3600]], 

 [[    0.3500,    0.3600,    0.3600,    0.3500,    0.3300,    0.2750,    0.2800,    0.2900,    0.2950,    0.3150]], 

 [[    0.3400,    0.2850,    0.2750,    0.2900,    0.3150,    0.3150,    0.3200,    0.2850,    0.2800,    0.2950]], 

 [[    0.3150,    0.3050,    0.3050,    0.3100,    0.3000,    0.3200,    0.3250,    0.3100,    0.3150,    0.3150]], 

 [[    0.3100,    0.2850,    0.3050,    0.3400,    0.4050,    0.4750,    0.4350,    0.4500,    0.4300,    0.4350]], 

 [[    0.3950,    0.4100,    0.4450,    0.4450,    0.4400,    0.3950,    0.3850,    0.3200,    0.3550,    0.3050]], 

 [[    0.2800,    0.2750,    0.3100,    0.3600,    0.3750,    0.3550,    0.3700,    0.3850,    0.3900,    0.3650]], 

 [[    0.3350,    0.3400,    0.3650,    0.3750,    0.3650,    0.3700,    0.3700,    0.3500,    0.3500,    0.3000]], 

 [[    0.3000,    0.3100,    0.3000,    0.3100,    0.2950,    0.3050,    0.3150,    0.3100,    0.3000,    0.2900]], 

 [[    0.2950,    0.2900,    0.2900,    0.2950,    0.2950,    0.2950,    0.2800,    0.2250,    0.2150,    0.2200]], 

 [[    0.2100,    0.2100,    0.1700,    0.1650,    0.1550,    0.1250,    0.1700,    0.2250,    0.2750,    0.2650]], 

 [[    0.2250,    0.2200,    0.2100,    0.2250,    0.2350,    0.2350,    0.2700,    0.2700,    0.2700,    0.2650]], 

 [[    0.2550,    0.2600,    0.2550,    0.2000,    0.2050,    0.2150,    0.2350,    0.2200,    0.1900,    0.1850]], 

 [[    0.1750,    0.1750,    0.1900,    0.1700,    0.1450,    0.1450,    0.1500,    0.1700,    0.1900,    0.1800]], 

 [[    0.2000,    0.2100,    0.2000,    0.1800,    0.1650,    0.1700,    0.1450,    0.1350,    0.1250,    0.1200]], 

 [[    0.1250,    0.1300,    0.1350,    0.1350,    0.1300,    0.1200,    0.1200,    0.1100,    0.1200,    0.1000]], 

 [[    0.0850,    0.0850,    0.0800,    0.0700,    0.0750,    0.0950,    0.0950,    0.0950,    0.0950,    0.0850]], 

 [[    0.0750,    0.0750,    0.0950,    0.1100,    0.1100,    0.0900,    0.0850,    0.0700,    0.0700,    0.0650]], 

 [[    0.0400,    0.0200,    0.0200,         0,    0.0150,    0.0250,    0.0550,    0.0600,    0.0700,    0.0900]], 

 [[    0.0950,    0.1300,    0.1300,    0.1350,    0.1250,    0.1000,    0.1000,    0.0750,    0.0950,    0.1000]], 

 [[    0.1150,    0.0800,    0.0750,    0.0550,    0.0500,    0.0550,    0.0850,    0.0900,    0.0850,    0.0850]], 

 [[    0.0950,    0.1300,    0.1300,    0.1350,    0.1500,    0.1750,    0.2050,    0.2500,    0.2700,    0.2350]], 

 [[    0.2450,    0.2800,    0.2600,    0.2900,    0.2950,    0.2550,    0.2800,    0.2650,    0.2850,    0.2900]], 

 [[    0.3150,    0.3100,    0.2850,    0.2700,    0.2750,    0.2900,    0.3100,    0.3100,    0.2850,    0.2850]], 

 [[    0.2950,    0.2950,    0.2900,    0.3000,    0.3100,    0.3150,    0.3700,    0.4350,    0.5000,    0.5200]], 

 [[    0.4900,    0.5000,    0.5050,    0.4800,    0.5000,    0.5200,    0.5750,    0.6000,    0.5950,    0.5350]], 

 [[    0.5400,    0.5650,    0.5650,    0.5750,    0.6300,    0.7050,    0.7900,    0.8000,    0.7650,    0.7150]], 

 [[    0.7550,    0.8000,    0.8800,    0.8550,    0.8250,    0.7600,    0.8150,    0.8150,    0.7950,    0.6850]], 

 [[    0.7150,    0.7400,    0.7200,    0.7150,    0.8350,    0.8600,    0.8600,    0.8300,    0.9000,    0.9150]]]
===========INPUT MASK===================
[[    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000]]
===========OUTPUT MASK===================
[[    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000], 
 [    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000,    1.0000]]

LSTM Network

We define the LSTM multilayer network

In [23]:
val seed = 12345;

val periods = iteratorTrain.inputPeriods() + iteratorTrain.outcomePeriods()
val lstmLayer1Size = periods*2;
val lstmLayer2Size = periods;
val denseLayerSize = periods;
var truncatedBPTTLength = 250
val dropoutRatio = 0.8;
var nIn = iteratorTrain.inputColumns()
var nOut = iteratorTrain.totalOutcomes()

var conf = new NeuralNetConfiguration.Builder()
    .seed(seed)
    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
    .weightInit(WeightInit.XAVIER)
    .updater(Updater.ADAGRAD)  // RMSPROP or ADAGRAD
    .l2(1e-2)
    .list()
    .layer(0, new GravesLSTM.Builder()
            .nIn(nIn)
            .nOut(lstmLayer1Size)
            .gateActivationFunction(Activation.SOFTSIGN)
            .dropOut(dropoutRatio)
            .build())
    .layer(1, new GravesLSTM.Builder()
            .nIn(lstmLayer1Size)
            .nOut(lstmLayer2Size)
            .gateActivationFunction(Activation.SOFTSIGN)
            .dropOut(dropoutRatio)
            .build())
    .layer(2, new RnnOutputLayer.Builder()
            .nIn(lstmLayer2Size)
            .nOut(nOut)
            .activation(Activation.IDENTITY)
            .lossFunction(LossFunctions.LossFunction.MSE)
            .build())
    .backpropType(BackpropType.TruncatedBPTT)
    .tBPTTForwardLength(truncatedBPTTLength)
    .tBPTTBackwardLength(truncatedBPTTLength)
    .pretrain(false)
    .backprop(true)
    .build();

var net = new MultiLayerNetwork(conf);
net.init()

conf
Out[23]:
{
  "backprop" : true,
  "backpropType" : "TruncatedBPTT",
  "cacheMode" : "NONE",
  "confs" : [ {
    "cacheMode" : "NONE",
    "epochCount" : 0,
    "iterationCount" : 0,
    "layer" : {
      "@class" : "org.deeplearning4j.nn.conf.layers.GravesLSTM",
      "activationFn" : {
        "@class" : "org.nd4j.linalg.activations.impl.ActivationSigmoid"
      },
      "biasInit" : 0.0,
      "biasUpdater" : null,
      "constraints" : null,
      "dist" : null,
      "distRecurrent" : null,
      "forgetGateBiasInit" : 1.0,
      "gateActivationFn" : {
        "@class" : "org.nd4j.linalg.activations.impl.ActivationSoftSign"
      },
      "gradientNormalization" : "None",
      "gradientNormalizationThreshold" : 1.0,
      "idropout" : {
        "@class" : "org.deeplearning4j.nn.conf.dropout.Dropout",
        "p" : 0.8,
        "pschedule" : null
      },
      "iupdater" : {
        "@class" : "org.nd4j.linalg.learning.config.AdaGrad",
        "epsilon" : 1.0E-6,
        "learningRate" : 0.1
      },
      "l1" : 0.0,
      "l1Bias" : 0.0,
      "l2" : 0.01,
      "l2Bias" : 0.0,
      "layerName" : "layer0",
      "nin" : 5,
      "nout" : 40,
      "pretrain" : false,
      "weightInit" : "XAVIER",
      "weightInitRecurrent" : null,
      "weightNoise" : null
    },
    "maxNumLineSearchIterations" : 5,
    "miniBatch" : true,
    "minimize" : true,
    "optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
    "pretrain" : false,
    "seed" : 12345,
    "stepFunction" : null,
    "variables" : [ "W", "RW", "b" ]
  }, {
    "cacheMode" : "NONE",
    "epochCount" : 0,
    "iterationCount" : 0,
    "layer" : {
      "@class" : "org.deeplearning4j.nn.conf.layers.GravesLSTM",
      "activationFn" : {
        "@class" : "org.nd4j.linalg.activations.impl.ActivationSigmoid"
      },
      "biasInit" : 0.0,
      "biasUpdater" : null,
      "constraints" : null,
      "dist" : null,
      "distRecurrent" : null,
      "forgetGateBiasInit" : 1.0,
      "gateActivationFn" : {
        "@class" : "org.nd4j.linalg.activations.impl.ActivationSoftSign"
      },
      "gradientNormalization" : "None",
      "gradientNormalizationThreshold" : 1.0,
      "idropout" : {
        "@class" : "org.deeplearning4j.nn.conf.dropout.Dropout",
        "p" : 0.8,
        "pschedule" : null
      },
      "iupdater" : {
        "@class" : "org.nd4j.linalg.learning.config.AdaGrad",
        "epsilon" : 1.0E-6,
        "learningRate" : 0.1
      },
      "l1" : 0.0,
      "l1Bias" : 0.0,
      "l2" : 0.01,
      "l2Bias" : 0.0,
      "layerName" : "layer1",
      "nin" : 40,
      "nout" : 20,
      "pretrain" : false,
      "weightInit" : "XAVIER",
      "weightInitRecurrent" : null,
      "weightNoise" : null
    },
    "maxNumLineSearchIterations" : 5,
    "miniBatch" : true,
    "minimize" : true,
    "optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
    "pretrain" : false,
    "seed" : 12345,
    "stepFunction" : null,
    "variables" : [ "W", "RW", "b" ]
  }, {
    "cacheMode" : "NONE",
    "epochCount" : 0,
    "iterationCount" : 0,
    "layer" : {
      "@class" : "org.deeplearning4j.nn.conf.layers.RnnOutputLayer",
      "activationFn" : {
        "@class" : "org.nd4j.linalg.activations.impl.ActivationIdentity"
      },
      "biasInit" : 0.0,
      "biasUpdater" : null,
      "constraints" : null,
      "dist" : null,
      "gradientNormalization" : "None",
      "gradientNormalizationThreshold" : 1.0,
      "hasBias" : true,
      "idropout" : null,
      "iupdater" : {
        "@class" : "org.nd4j.linalg.learning.config.AdaGrad",
        "epsilon" : 1.0E-6,
        "learningRate" : 0.1
      },
      "l1" : 0.0,
      "l1Bias" : 0.0,
      "l2" : 0.01,
      "l2Bias" : 0.0,
      "layerName" : "layer2",
      "lossFn" : {
        "@class" : "org.nd4j.linalg.lossfunctions.impl.LossMSE",
        "configProperties" : false,
        "numOutputs" : -1
      },
      "nin" : 20,
      "nout" : 1,
      "pretrain" : false,
      "weightInit" : "XAVIER",
      "weightNoise" : null
    },
    "maxNumLineSearchIterations" : 5,
    "miniBatch" : true,
    "minimize" : true,
    "optimizationAlgo" : "STOCHASTIC_GRADIENT_DESCENT",
    "pretrain" : false,
    "seed" : 12345,
    "stepFunction" : null,
    "variables" : [ "W", "b" ]
  } ],
  "epochCount" : 0,
  "inferenceWorkspaceMode" : "ENABLED",
  "inputPreProcessors" : { },
  "iterationCount" : 0,
  "pretrain" : false,
  "tbpttBackLength" : 250,
  "tbpttFwdLength" : 250,
  "trainingWorkspaceMode" : "ENABLED"
}

Fitting the Network

Next we train the network in with n epochs:

In [24]:
println("Training: ")
var nEpochs = 50
var client = NamespaceClient.getBeakerX()
for(i <- 0  to nEpochs-1 ) {
    iteratorTrain.reset(); 
    while (iteratorTrain.hasNext()) {
        var data = iteratorTrain.next()
        var idx = iteratorTrain.currentIndex()
        var maxIdx = iteratorTrain.maxIndex()
        client.showProgressUpdate("", ((i * maxIdx + idx) * 100) / (nEpochs * maxIdx) )
        net.fit(data); 
    }
}

"Done"
Training: 
Out[24]:
Done

Evaluation

We create the test data in order to evaluate the predictions with the IndicatorSplitter. We use the data starting from 2018-01-01.

In [25]:
import scala.collection.JavaConverters._

var splitDate = Context.date("2018-01-01")
var inputTest = IndicatorSplitter.split(in.toList.asJava.asInstanceOf[java.util.List[Indicator[Num]]], splitDate, false)
var outputTest = IndicatorSplitter.split(out.toList.asJava.asInstanceOf[java.util.List[Indicator[Num]]], splitDate, false)
var iteratorTest = new StockData3DIterator(inputTrain, outputTrain, 50, 10, 10, 10); 
iteratorTest.setScalingPerDataset(true)
Out[25]:
null

Then we determine forecasted values and compare them with the actual values: In each batch we need to take the all values where the mask returns a 1.0 value:

In [26]:
var resultMap = new ListBuffer[Map[String,Any]] 
var ev = new RegressionEvaluation(1)
iteratorTest.reset();
net.rnnClearPreviousState()

while (iteratorTest.hasNext()) {
    var data = iteratorTest.next().asInstanceOf[ScalingDataSet]
    var mask = data.getLabelsMaskArray() 
    var labels = data.getLabels()
    var prediction = net.rnnTimeStep(data.getFeatures());
    
    ev.eval(labels, prediction)
    for (j <- 0 to iteratorTest.inputPeriods()-1) {
        if (mask.getDouble(0l,j) == 1.0) {    
            resultMap += scala.collection.Map("actual" -> labels.getDouble(0l,0l, j), "predict" -> prediction.getDouble(0l,0l,j))
        }
    }
}

ev.stats
Out[26]:
Column    MSE            MAE            RMSE           RSE            PC             R^2            
col_0     1.26151e-02    7.75928e-02    1.12317e-01    1.62705e-01    9.19468e-01    8.37295e-01    

Here is the (scaled) result as table:

In [27]:
import scala.collection.JavaConverters._

resultMap.map(r => r.asJava).asJava

... and the result as chart:

In [28]:
val actualLine = new Line() {
    x = 1 to resultMap.size
    y = resultMap.map(map => map.get("actual").get.asInstanceOf[Double])
    displayName = "actual"
}
val predictLine = new Line() {
    x = 1 to resultMap.size
    y = resultMap.map(map => map.get("predict").get.asInstanceOf[Double])
    displayName = "predict"
}

new Plot().add(Seq(actualLine, predictLine))

Full Data

Finally here is the unscaled result over the full data. We use the revertLabels method to revert the scaling

In [34]:
import scala.collection.JavaConverters._

var resultList = new ListBuffer[Double] 
var iterator = new StockData3DIterator(in.asJava, out.asJava, 1, 10, 10, 10); 
iterator.setScalingPerDataset(true)

net.rnnClearPreviousState()

while (iterator.hasNext()) {
    var data = iterator.next().asInstanceOf[ScalingDataSet]
    var mask = data.getLabelsMaskArray() 
    var labels = data.getLabels()
    data.revertLabels(labels, mask)
    
    var prediction = net.rnnTimeStep(data.getFeatures());    
    data.revertLabels(prediction, mask)
    
    ev.eval(labels, prediction)
    for (j <- 0 to iteratorTest.inputPeriods()-1) {
        if (mask.getDouble(0l,j) == 1.0) {    
            resultList +=  prediction.getDouble(0l,0l,j)
        }
    }
}
Out[34]:
null

Here is the result as Chart

In [35]:
var actualValues = HistoricValues.create(close)
val actualLine = new Line() {
    x = 1 to actualValues.size()
    y = actualValues.getValues().asScala
    displayName = "actual"
}

val predictLine = new Line() {
    x = 1 to resultList.size
    y = resultList
    displayName = "predictLine"
}

new Plot().add(Seq(actualLine, predictLine))

LSTMForecast

The functionality which has been described in detail above can be executed with only a few lines of code with the help of the LSTMForecast class:

In [36]:
import scala.collection.JavaConverters._

var iterator = new StockData3DIterator(in.asJava, out.asJava, 1, 10, 10, 10); 
iterator.setScalingPerDataset(true)
var epochs = 10
var forecast = new LSTMForecast(iterator, epochs);
var onlyForecast = new LSTMForecast(iterator, forecast.getNet());
onlyForecast.setName("OnlyForecast")
onlyForecast.setOnlyPredictions(true)

var tableFromIndicators = Table.create(new ForecastIndicator(forecast,1),new ForecastIndicator(onlyForecast,1))
In [32]:
new SimpleTimePlot {
    data = tableFromIndicators.seq()
    columns = Seq("LSTMForecast","OnlyForecast")
}
In [ ]: