First, we have to load TensorFlow Scala into our kernel.
// Choose the TensorFlow Scala native-library classifier for this machine:
// darwin-cpu-x86_64 on macOS, linux-cpu-x86_64 everywhere else. Set the
// TF_SCALA_CLASSIFIER environment variable to override the choice, e.g.
// linux-gpu-x86_64 on Linux with an NVIDIA GPU (a GPU cannot be detected here).
val tfClassifier = sys.env.getOrElse(
  "TF_SCALA_CLASSIFIER",
  if (sys.props.getOrElse("os.name", "").toLowerCase(java.util.Locale.ROOT).contains("mac"))
    "darwin-cpu-x86_64"
  else
    "linux-cpu-x86_64"
)
interp.load.ivy(
  ("org.platanios" %% "tensorflow" % "0.4.1").withClassifier(tfClassifier),
  "org.platanios" %% "tensorflow-data" % "0.4.1"
)
import org.platanios.tensorflow.api._
import org.platanios.tensorflow.api.learn._
import org.platanios.tensorflow.api.learn.layers._
import org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator
import org.platanios.tensorflow.data.image.MNISTLoader
import java.nio.file.Paths
import scala.util.Random
import org.platanios.tensorflow.api._ import org.platanios.tensorflow.api.learn._ import org.platanios.tensorflow.api.learn.layers._ import org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator import org.platanios.tensorflow.data.image.MNISTLoader import java.nio.file.Paths import scala.util.Random
// Load MNIST using tmp/mnist as the data directory. On this run the archives were downloaded
// there and then extracted (see the log below). The result holds the train/test images as
// UByte tensors of shape [N, 28, 28] and the train/test labels as UByte tensors of shape [N].
val dataset = MNISTLoader.load(Paths.get("tmp/mnist"))
2018-11-25 22:11:45.197 [scala-interpreter-1] INFO MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'. 2018-11-25 22:11:48.007 [scala-interpreter-1] INFO MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz'. 2018-11-25 22:11:48.008 [scala-interpreter-1] INFO MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'. 2018-11-25 22:11:48.114 [scala-interpreter-1] INFO MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz'. 2018-11-25 22:11:48.114 [scala-interpreter-1] INFO MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'. 2018-11-25 22:11:48.550 [scala-interpreter-1] INFO MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz'. 2018-11-25 22:11:48.552 [scala-interpreter-1] INFO MNIST Data Loader - Downloading file 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'. 2018-11-25 22:11:48.654 [scala-interpreter-1] INFO MNIST Data Loader - Downloaded file 'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz'. 2018-11-25 22:11:48.655 [scala-interpreter-1] INFO MNIST Data Loader - Extracting images from file 'tmp/mnist/train-images-idx3-ubyte.gz'. 2018-11-25 22:11:50.320 [scala-interpreter-1] INFO MNIST Data Loader - Extracting labels from file 'tmp/mnist/train-labels-idx1-ubyte.gz'. 2018-11-25 22:11:50.323 [scala-interpreter-1] INFO MNIST Data Loader - Extracting images from file 'tmp/mnist/t10k-images-idx3-ubyte.gz'. 2018-11-25 22:11:50.392 [scala-interpreter-1] INFO MNIST Data Loader - Extracting labels from file 'tmp/mnist/t10k-labels-idx1-ubyte.gz'. 2018-11-25 22:11:50.393 [scala-interpreter-1] INFO MNIST Data Loader - Finished loading the MNIST dataset.
dataset: org.platanios.tensorflow.data.image.MNISTDataset = MNISTDataset( MNIST, Tensor[UByte, [60000, 28, 28]], Tensor[UByte, [60000]], Tensor[UByte, [10000, 28, 28]], Tensor[UByte, [10000]] )
Let's display a few random example images from the test set.
// Session that `showImage` uses to run its PNG-encoding ops.
val sess = Session()

/** Encodes `image` as PNG and displays it, 50 pixels wide, in the notebook output.
  *
  * @param image image tensor with a trailing channel dimension, e.g. `[28, 28, 1]`
  *              (the callers in this notebook add it via `expandDims(-1)`).
  */
def showImage(image: Tensor[UByte]): Unit = {
  val pngOp = tf.decodeRaw[Byte](tf.image.encodePng(image))
  val pngBytes = sess.run(fetches = pngOp).entriesIterator.toArray
  Image(pngBytes)
    .withFormat(Image.PNG)
    .withWidth(50)
    .display()
}

// Draw five random test-set indices (repeats are possible), then turn each selected
// image into a [28, 28, 1] tensor by appending a channel dimension.
val randomImages = Vector
  .fill(5)(Random.nextInt(dataset.testImages.shape(0)))
  .map(index => dataset.testImages(index).expandDims(-1))
randomImages.foreach(showImage)
sess: Session = org.platanios.tensorflow.api.core.client.Session@481ca54d defined function showImage randomImages: collection.immutable.IndexedSeq[Tensor[UByte]] = Vector( Tensor[UByte, [28, 28, 1]], Tensor[UByte, [28, 28, 1]], Tensor[UByte, [28, 28, 1]], Tensor[UByte, [28, 28, 1]], Tensor[UByte, [28, 28, 1]] )
// Build the training input pipeline: pair every image (cast to Float) with its label
// (cast to Long), repeat the data indefinitely, shuffle, batch and prefetch.
// NOTE(review): pixels stay in the raw UByte range 0-255 after `toFloat`; scaling them to
// [0, 1] would probably allow a larger learning rate. Not changed here, because inference
// further down feeds the same unscaled values.
val trainImages = tf.data.datasetFromTensorSlices(dataset.trainImages.toFloat)
val trainLabels = tf.data.datasetFromTensorSlices(dataset.trainLabels.toLong)
val trainData = {
  val shuffleBufferSize = 10000
  val batchSize = 256
  val prefetchBufferSize = 10
  trainImages
    .zip(trainLabels)
    .repeat()
    .shuffle(shuffleBufferSize)
    .batch(batchSize)
    .prefetch(prefetchBufferSize)
}
trainImages: ops.data.Dataset[Output[Float]] = Dataset[TensorSlicesDataset] trainLabels: ops.data.Dataset[Output[Long]] = Dataset[TensorSlicesDataset] trainData: ops.data.Dataset[(Output[Float], Output[Long])] = Dataset[TensorSlicesDataset/Zip/Repeat/Shuffle/Batch/Prefetch]
// Create the MLP model. Each 28x28 image is flattened and passed through three hidden
// Linear + ReLU layers (128, 64 and 32 units), then a final Linear layer with 10 outputs.
val input = Input(FLOAT32, Shape(-1, 28, 28))
val trainInput = Input(INT64, Shape(-1))

// Second argument given to every hidden ReLU (presumably the leaky-ReLU slope for
// negative inputs — confirm against the TensorFlow Scala ReLU layer docs).
val reluAlpha = 0.1f
val layer =
  Flatten[Float]("Input/Flatten") >>
    Linear[Float]("Layer_0", 128) >>
    ReLU[Float]("Layer_0/Activation", reluAlpha) >>
    Linear[Float]("Layer_1", 64) >>
    ReLU[Float]("Layer_1/Activation", reluAlpha) >>
    Linear[Float]("Layer_2", 32) >>
    ReLU[Float]("Layer_2/Activation", reluAlpha) >>
    Linear[Float]("OutputLayer", 10)

// Sparse softmax cross-entropy of the 10 outputs against the Long labels, reduced by Mean.
val loss =
  SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
    Mean("Loss/Mean")

val learningRate = 1e-6f
val optimizer = tf.train.GradientDescent(learningRate)
val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)
input: Input[Output[Float]] = org.platanios.tensorflow.api.learn.layers.Input@24eac06 trainInput: Input[Output[Long]] = org.platanios.tensorflow.api.learn.layers.Input@7dec458b layer: Compose[Output[Float], Output[Float], Output[Float]] = Compose( "Input/Flatten", Compose( "Input/Flatten", Compose( "Input/Flatten", Compose( "Input/Flatten", Compose( "Input/Flatten", Compose( "Input/Flatten", Compose( "Input/Flatten", Flatten("Input/Flatten"), Linear( "Layer_0", 128, true, RandomNormalInitializer( Tensor[Float, []], Tensor[Float, []], None ), RandomNormalInitializer( Tensor[Float, []], Tensor[Float, []], None ) ) ), ReLU("Layer_0/Activation", 0.1F) ), Linear( "Layer_1", 64, true, RandomNormalInitializer(Tensor[Float, []], Tensor[Float, []], None), RandomNormalInitializer(Tensor[Float, []], Tensor[Float, []], None) ... loss: Compose[(Output[Float], Output[Long]), Output[Float], Output[Float]] = Compose( "Loss", SparseSoftmaxCrossEntropy("Loss"), Mean("Loss/Mean", null, false) ) optimizer: ops.training.optimizers.GradientDescent = org.platanios.tensorflow.api.ops.training.optimizers.GradientDescent@6aca6d9f model: SupervisedTrainableModel[Output[Float], Output[Long], Output[Float], Output[Float], Float] = org.platanios.tensorflow.api.learn.Model$$anon$1@fe85c89
import org.platanios.tensorflow.api.learn.hooks._
import org.platanios.tensorflow.api.config.TensorBoardConfig

// Extend the loss with a scalar summary tagged "Loss" so its value can be written out for TensorBoard.
val loss = SparseSoftmaxCrossEntropy[Float, Long, Float]("Loss") >>
  Mean("Loss/Mean") >>
  ScalarSummary(name = "Loss", tag = "Loss")

// The model built in the previous cell captured the earlier `loss` value, which has no summary.
// Rebuild it here; otherwise the ScalarSummary above never becomes part of the training graph
// and SummarySaver has no loss value to record.
val model = Model.simpleSupervised(input, trainInput, layer, loss, optimizer)

// One directory shared by the estimator configuration, the summary and checkpoint hooks,
// and the TensorBoard config.
val summariesDir = Paths.get("tmp/summaries")
val estimator = InMemoryEstimator(
  modelFunction = model,
  configurationBase = Configuration(Some(summariesDir)),
  trainHooks = Set(
    // Save summaries every 100 steps and checkpoints every 1000 steps.
    SummarySaver(summariesDir, StepHookTrigger(100)),
    CheckpointSaver(summariesDir, StepHookTrigger(1000))),
  tensorBoardConfig = TensorBoardConfig(summariesDir))
import org.platanios.tensorflow.api.learn.hooks._ import org.platanios.tensorflow.api.config.TensorBoardConfig loss: Compose[(Output[Float], Output[Long]), Output[Float], Output[Float]] = Compose( "Loss", Compose( "Loss", SparseSoftmaxCrossEntropy("Loss"), Mean("Loss/Mean", null, false) ), ScalarSummary( "Loss", "Loss", null, Set(org.platanios.tensorflow.api.core.Graph$Keys$SUMMARIES$@5ed9cf9e) ) ) summariesDir: java.nio.file.Path = tmp/summaries estimator: InMemoryEstimator[Output[Float], (Output[Float], Output[Long]), Output[Float], Output[Float], Float, (Output[Float], (Output[Float], Output[Long]))] = org.platanios.tensorflow.api.learn.estimators.InMemoryEstimator@2f529710
// Train for at most 10,000 steps. `trainData` repeats indefinitely, so this step limit is
// what ends training.
val stopCriteria = StopCriteria(maxSteps = Some(10000))
estimator.train(() => trainData, stopCriteria)
2018-11-25 22:12:19.315 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 0. 2018-11-25 22:12:19.317 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-0'. 2018-11-25 22:12:24.316 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-0'. 2018-11-25 22:12:27.875 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 1000. 2018-11-25 22:12:27.876 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-1000'. 2018-11-25 22:12:28.805 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-1000'. 2018-11-25 22:12:33.241 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 2000. 2018-11-25 22:12:33.242 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-2000'. 2018-11-25 22:12:34.679 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-2000'. 2018-11-25 22:12:38.452 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 3000. 2018-11-25 22:12:38.452 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-3000'. 2018-11-25 22:12:39.675 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-3000'. 2018-11-25 22:12:45.426 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 4000. 
2018-11-25 22:12:45.427 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-4000'. 2018-11-25 22:12:46.311 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-4000'. 2018-11-25 22:12:50.535 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 5000. 2018-11-25 22:12:50.536 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-5000'. 2018-11-25 22:12:51.376 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-5000'. 2018-11-25 22:12:54.788 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 6000. 2018-11-25 22:12:54.789 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-6000'. 2018-11-25 22:12:55.616 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-6000'. 2018-11-25 22:13:00.133 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 7000. 2018-11-25 22:13:00.136 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-7000'. 2018-11-25 22:13:01.684 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-7000'. 2018-11-25 22:13:05.769 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 8000. 2018-11-25 22:13:05.770 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-8000'. 
2018-11-25 22:13:06.706 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-8000'. 2018-11-25 22:13:09.755 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 9000. 2018-11-25 22:13:09.756 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-9000'. 2018-11-25 22:13:10.888 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-9000'. 2018-11-25 22:13:13.829 [scala-interpreter-1] INFO Learn / Hooks / Checkpoint Saver - Saving checkpoint for step 10000. 2018-11-25 22:13:13.829 [scala-interpreter-1] INFO Variables / Saver - Saving parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-10000'. 2018-11-25 22:13:14.605 [scala-interpreter-1] INFO Variables / Saver - Saved parameters to '/Users/brunksn/repos/my/almond-examples/tmp/summaries/model.ckpt-10000'.
// Run inference on the first `numExamples` test images and show each image followed by
// its predicted class.
val numExamples = 10
val images = dataset.testImages(0 :: numExamples)
// `result` has one row of 10 output scores per image (shape [numExamples, 10]); the predicted
// class is the index of the largest score in each row.
val result = estimator.infer(() => images.toFloat)
val predictedClasses = result.argmax(1)
for (i <- 0 until images.shape(0)) {
  showImage(images(i).expandDims(-1))
  // Use a single println so each prediction ends with a newline. The original two `print`
  // calls never emitted one, so consecutive predictions could run together in the text output.
  println(s"predicted class: ${predictedClasses(i).scalar}")
}
predicted class: 7
predicted class: 2
predicted class: 1
predicted class: 0
predicted class: 9
predicted class: 1
predicted class: 9
predicted class: 9
predicted class: 2
predicted class: 9
images: Tensor[UByte] = Tensor[UByte, [10, 28, 28]] result: Tensor[Float] = Tensor[Float, [10, 10]] predictedClasses: Tensor[Long] = Tensor[Long, [10]]