TensorFlow Scala Example

First we have to load TensorFlow Scala into our kernel.

In [1]:
interp.load.ivy(
  // Replace with linux-gpu-x86_64 on Linux with an NVIDIA GPU, or with darwin-cpu-x86_64 on macOS.
  ("org.platanios" %% "tensorflow" % "0.4.1").withClassifier("linux-cpu-x86_64"),
  "org.platanios" %% "tensorflow-data" % "0.4.1"
)
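
For reference, roughly the same dependencies in an sbt build would look like the sketch below (an approximation; sbt's classifier syntax differs slightly from the coursier-style syntax used above).

// Approximate build.sbt equivalent (not part of this notebook):
libraryDependencies ++= Seq(
  "org.platanios" %% "tensorflow" % "0.4.1" classifier "linux-cpu-x86_64",
  "org.platanios" %% "tensorflow-data" % "0.4.1"
)
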
In [2]:
import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.learn._
import org.platanios.tensorflow.api.learn.layers._
import org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator
import org.platanios.tensorflow.data.image.MNISTLoader

import java.nio.file.Paths
import scala.util.Random
Out[2]:
import org.platanios.tensorflow.api._

import org.platanios.tensorflow.api.learn._

import org.platanios.tensorflow.api.learn.layers._

import org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator

import org.platanios.tensorflow.data.image.MNISTLoader


import java.nio.file.Paths

import scala.util.Random
In [3]:
val dataset = MNISTLoader.load(Paths.get("tmp/mnist"))
2018-11-25 22:11:45.197 [scala-interpreter-1] INFO  MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'.
2018-11-25 22:11:48.007 [scala-interpreter-1] INFO  MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'.
2018-11-25 22:11:48.008 [scala-interpreter-1] INFO  MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'.
2018-11-25 22:11:48.114 [scala-interpreter-1] INFO  MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'.
2018-11-25 22:11:48.114 [scala-interpreter-1] INFO  MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'.
2018-11-25 22:11:48.550 [scala-interpreter-1] INFO  MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'.
2018-11-25 22:11:48.552 [scala-interpreter-1] INFO  MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'.
2018-11-25 22:11:48.654 [scala-interpreter-1] INFO  MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'.
2018-11-25 22:11:48.655 [scala-interpreter-1] INFO  MNIST Data Loader - Extracting images from file 'tmp/mnist/train-images-idx3-ubyte.gz'.
2018-11-25 22:11:50.320 [scala-interpreter-1] INFO  MNIST Data Loader - Extracting labels from file 'tmp/mnist/train-labels-idx1-ubyte.gz'.
2018-11-25 22:11:50.323 [scala-interpreter-1] INFO  MNIST Data Loader - Extracting images from file 'tmp/mnist/t10k-images-idx3-ubyte.gz'.
2018-11-25 22:11:50.392 [scala-interpreter-1] INFO  MNIST Data Loader - Extracting labels from file 'tmp/mnist/t10k-labels-idx1-ubyte.gz'.
2018-11-25 22:11:50.393 [scala-interpreter-1] INFO  MNIST Data Loader - Finished loading the MNIST dataset.
Out[3]:
dataset: org.platanios.tensorflow.data.image.MNISTDataset = MNISTDataset(
  MNIST,
  Tensor[UByte, [60000, 28, 28]],
  Tensor[UByte, [60000]],
  Tensor[UByte, [10000, 28, 28]],
  Tensor[UByte, [10000]]
)
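
The loader returns the train and test sets as plain tensors, so we can sanity-check their shapes before building a pipeline. A quick sketch (not run as a numbered cell here), using only the accessors printed above:

// Shape sanity check; expected values follow the MNISTDataset output above.
println(dataset.trainImages.shape) // [60000, 28, 28]
println(dataset.trainLabels.shape) // [60000]
println(dataset.testImages.shape)  // [10000, 28, 28]
println(dataset.testLabels.shape)  // [10000]
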

Let's display a few example images from our test set. Note that tf.image.encodePng expects an explicit channel dimension, which is why each 28x28 image is expanded to 28x28x1 below.

In [4]:
val sess = Session()
def showImage(image: Tensor[UByte]): Unit = {
  val exampleImage = tf.decodeRaw[Byte](tf.image.encodePng(image))
  val png = sess.run(fetches = exampleImage)
  Image(png.entriesIterator.toArray).withFormat(Image.PNG).withWidth(50).display()
}

val randomImages = for {
  index <- (0 to 4).map(_ => Random.nextInt(dataset.testImages.shape(0)))
  image = dataset.testImages(index).expandDims(-1)
} yield image
randomImages.foreach(showImage(_))
Out[4]:
sess: Session = org.platanios.tensorflow.api.core.client.Session@…
defined function showImage
randomImages: collection.immutable.IndexedSeq[Tensor[UByte]] = Vector(
  Tensor[UByte, [28, 28, 1]],
  Tensor[UByte, [28, 28, 1]],
  Tensor[UByte, [28, 28, 1]],
  Tensor[UByte, [28, 28, 1]],
  Tensor[UByte, [28, 28, 1]]
)
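
A small variant of the same idea, in case you also want the ground-truth label next to each sampled image (a sketch reusing only the accessors above; not executed as a cell):

// Sample five random test images and print each one's label below it.
for (index <- (0 to 4).map(_ => Random.nextInt(dataset.testImages.shape(0)))) {
  showImage(dataset.testImages(index).expandDims(-1))
  println(s"label: ${dataset.testLabels(index).scalar}")
}
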
In [5]:
// Load and batch data using pre-fetching.
val trainImages = tf.data.datasetFromTensorSlices(dataset.trainImages.toFloat)
val trainLabels = tf.data.datasetFromTensorSlices(dataset.trainLabels.toLong)
val trainData =
  trainImages.zip(trainLabels)
      .repeat()
      .shuffle(10000)
      .batch(256)
      .prefetch(10)
Out[5]:
trainImages: ops.data.Dataset[Output[Float]] = Dataset[TensorSlicesDataset]
trainLabels: ops.data.Dataset[Output[Long]] = Dataset[TensorSlicesDataset]
trainData: ops.data.Dataset[(Output[Float], Output[Long])] = Dataset[TensorSlicesDataset/Zip/Repeat/Shuffle/Batch/Prefetch]
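
The same combinators give us an evaluation pipeline: for the test set we skip repeat() and shuffle() so that each example is seen exactly once per pass (a sketch, not run here):

// Evaluation pipeline: batched and prefetched, but neither shuffled nor repeated.
val testImagesData = tf.data.datasetFromTensorSlices(dataset.testImages.toFloat)
val testLabelsData = tf.data.datasetFromTensorSlices(dataset.testLabels.toLong)
val testData = testImagesData.zip(testLabelsData).batch(256).prefetch(10)
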
In [6]:
// Create the MLP model.
val input = Input(FLOAT32, Shape(-1, 28, 28))
val trainInput = Input(INT64, Shape(-1))
val layer = Flatten[Float]("Input/Flatten") >>
    Linear[Float]("Layer_0", 128) >> ReLU[Float]("Layer_0/Activation", 0.1f) >>
    Linear[Float]("Layer_1", 64) >> ReLU[Float]("Layer_1/Activation", 0.1f) >>
    Linear[Float]("Layer_2", 32) >> ReLU[Float]("Layer_2/Activation", 0.1f) >>
    Linear[Float]("OutputLayer", 10)
val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
    Mean("Loss/Mean")
val optimizer = tf.train.GradientDescent(1e-6f)
val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)
Out[6]:
input: Input[Output[Float]] = org.platanios.tensorflow.api.learn.layers.Input@…
trainInput: Input[Output[Long]] = org.platanios.tensorflow.api.learn.layers.Input@…
layer: Compose[Output[Float], Output[Float], Output[Float]] = Compose(
  "Input/Flatten",
  Compose(
    "Input/Flatten",
    Compose(
      "Input/Flatten",
      Compose(
        "Input/Flatten",
        Compose(
          "Input/Flatten",
          Compose(
            "Input/Flatten",
            Compose(
              "Input/Flatten",
              Flatten("Input/Flatten"),
              Linear(
                "Layer_0",
                128,
                true,
                RandomNormalInitializer(
                  Tensor[Float, []],
                  Tensor[Float, []],
                  None
                ),
                RandomNormalInitializer(
                  Tensor[Float, []],
                  Tensor[Float, []],
                  None
                )
              )
            ),
            ReLU("Layer_0/Activation", 0.1F)
          ),
          Linear(
            "Layer_1",
            64,
            true,
            RandomNormalInitializer(Tensor[Float, []], Tensor[Float, []], None),
            RandomNormalInitializer(Tensor[Float, []], Tensor[Float, []], None)
...
loss: Compose[(Output[Float], Output[Long]), Output[Float], Output[Float]] = Compose(
  "Loss",
  SparseSoftmaxCrossEntropy("Loss"),
  Mean("Loss/Mean", null, false)
)
optimizer: ops.training.optimizers.GradientDescent = org.platanios.tensorflow.api.ops.training.optimizers.GradientDescent@6aca6d9f
model: SupervisedTrainableModel[Output[Float], Output[Long], Output[Float], Output[Float], Float] = org.platanios.tensorflow.api.learn.SupervisedTrainableModel@…
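
The >> operator composes layers left to right, which is what produces the nested Compose structure printed above. A minimal variant of the same pattern, with a single hidden layer and using only the constructors already shown:

// A smaller MLP built from the same layer types (a sketch, not run here).
val smallLayer = Flatten[Float]("Small/Flatten") >>
    Linear[Float]("Small/Layer_0", 32) >> ReLU[Float]("Small/Layer_0/Activation", 0.1f) >>
    Linear[Float]("Small/OutputLayer", 10)
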
In [7]:
import org.platanios.tensorflow.api.learn.hooks._
import org.platanios.tensorflow.api.config.TensorBoardConfig

val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
    Mean("Loss/Mean") >>
    ScalarSummary(name = "Loss", tag = "Loss")
val summariesDir = Paths.get("tmp/summaries")
val estimator = InMemoryEstimator(
  modelFunction = model,
  configurationBase = Configuration(Some(summariesDir)),
  trainHooks = Set(
    SummarySaver(summariesDir, StepHookTrigger(100)),
    CheckpointSaver(summariesDir, StepHookTrigger(1000))),
  tensorBoardConfig = TensorBoardConfig(summariesDir))
Out[7]:
import org.platanios.tensorflow.api.learn.hooks._

import org.platanios.tensorflow.api.config.TensorBoardConfig


loss: Compose[(Output[Float], Output[Long]), Output[Float], Output[Float]] = Compose(
  "Loss",
  Compose(
    "Loss",
    SparseSoftmaxCrossEntropy("Loss"),
    Mean("Loss/Mean", null, false)
  ),
  ScalarSummary(
    "Loss",
    "Loss",
    null,
    Set(org.platanios.tensorflow.api.core.Graph$Keys$SUMMARIES$@…)
  )
)
summariesDir: java.nio.file.Path = tmp/summaries
estimator: InMemoryEstimator[Output[Float], (Output[Float], Output[Long]), Output[Float], Output[Float], Float, (Output[Float], (Output[Float], Output[Long]))] = org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator@2f529710
In [8]:
estimator.train(() => trainData, StopCriteria(maxSteps = Some(10000)))
2018-11-25 22:12:19.315 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 0.
2018-11-25 22:12:19.317 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-0'.
2018-11-25 22:12:24.316 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-0'.
2018-11-25 22:12:27.875 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 1000.
2018-11-25 22:12:27.876 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-1000'.
2018-11-25 22:12:28.805 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-1000'.
2018-11-25 22:12:33.241 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 2000.
2018-11-25 22:12:33.242 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-2000'.
2018-11-25 22:12:34.679 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-2000'.
2018-11-25 22:12:38.452 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 3000.
2018-11-25 22:12:38.452 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-3000'.
2018-11-25 22:12:39.675 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-3000'.
2018-11-25 22:12:45.426 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 4000.
2018-11-25 22:12:45.427 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-4000'.
2018-11-25 22:12:46.311 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-4000'.
2018-11-25 22:12:50.535 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 5000.
2018-11-25 22:12:50.536 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-5000'.
2018-11-25 22:12:51.376 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-5000'.
2018-11-25 22:12:54.788 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 6000.
2018-11-25 22:12:54.789 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-6000'.
2018-11-25 22:12:55.616 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-6000'.
2018-11-25 22:13:00.133 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 7000.
2018-11-25 22:13:00.136 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-7000'.
2018-11-25 22:13:01.684 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-7000'.
2018-11-25 22:13:05.769 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 8000.
2018-11-25 22:13:05.770 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-8000'.
2018-11-25 22:13:06.706 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-8000'.
2018-11-25 22:13:09.755 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 9000.
2018-11-25 22:13:09.756 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-9000'.
2018-11-25 22:13:10.888 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-9000'.
2018-11-25 22:13:13.829 [scala-interpreter-1] INFO  Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 10000.
2018-11-25 22:13:13.829 [scala-interpreter-1] INFO  Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-10000'.
2018-11-25 22:13:14.605 [scala-interpreter-1] INFO  Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-10000'.
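
Since CheckpointSaver writes to the same directory the Configuration points at, calling train again should pick up from the latest checkpoint rather than restart (assuming the estimator's standard checkpoint-restore behavior). A sketch of continuing for another 5000 steps:

// Resume training up to step 15000; model.ckpt-10000 should be restored first
// (assumption: InMemoryEstimator restores from the working dir's latest checkpoint).
estimator.train(() => trainData, StopCriteria(maxSteps = Some(15000)))
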
In [9]:
val images = dataset.testImages(0::10)
val result = estimator.infer(() => images.toFloat)
val predictedClasses = result.argmax(1)
for (i <- 0 until images.shape(0)) {
    showImage(images(i).expandDims(-1))
    print("predicted class: ")
    print(predictedClasses(i).scalar)
}
predicted class: 7
predicted class: 2
predicted class: 1
predicted class: 0
predicted class: 9
predicted class: 1
predicted class: 9
predicted class: 9
predicted class: 2
predicted class: 9
Out[9]:
images: Tensor[UByte] = Tensor[UByte, [10, 28, 28]]
result: Tensor[Float] = Tensor[Float, [10, 10]]
predictedClasses: Tensor[Long] = Tensor[Long, [10]]
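
Ten samples is anecdotal; a rough accuracy number over the whole test set can be computed with the same infer call plus plain Scala counting (a sketch, not executed as a cell here):

// Predict all 10000 test images at once and count exact label matches.
val testPredictions = estimator.infer(() => dataset.testImages.toFloat)
val testPredicted = testPredictions.argmax(1)
val numCorrect = testPredicted.entriesIterator
  .zip(dataset.testLabels.toLong.entriesIterator)
  .count { case (p, y) => p == y }
println(s"test accuracy: ${numCorrect.toDouble / dataset.testLabels.shape(0)}")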