In this notebook, we build and train a deep CNN model using TensorFlow.js and visualize the predictions from the trained model on the MNIST dataset. You can build and train deep neural network machine learning models with tslab and TensorFlow.js without using Python.
Don't run this notebook on mybinder.org. The training of the CNN model in this notebook is very heavy and it will not finish on mybinder.org. Please try this notebook in your local environment with enough CPU power.
/**
* Copyright 2018 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*
* This code was branched from
* https://github.com/tensorflow/tfjs-examples/blob/master/mnist-node/
* to demonstrate TensorFlow.js in tslab.
*/
import * as tf from '@tensorflow/tfjs-node'
import Jimp from 'jimp';
import {promisify} from 'util';
import {dataset as mnist} from '../lib/mnist';
import {display} from 'tslab';
import * as tslab from 'tslab';
// Print the library versions in use so the results below are reproducible.
display.html('<h3>TensorFlow.js versions</h3>')
console.log(tf.version)
display.html('<h3>tslab versions</h3>')
console.log(tslab.versions);
{ 'tfjs-core': '1.3.1', 'tfjs-data': '1.3.1', 'tfjs-layers': '1.3.1', 'tfjs-converter': '1.3.1', tfjs: '1.3.1', 'tfjs-node': '1.3.1' }
{ tslab: '1.0.7', typescript: '3.7.3', node: 'v12.13.0' }
// Load the MNIST dataset used by getTrainData()/getTestData() below.
await mnist.loadData();
/**
 * Converts a contiguous slice of MNIST images into PNG-encoded buffers.
 *
 * Note: calling images.slice([index], [1]) once per image is slow, so the
 * whole range is sliced in one call and split manually.
 *
 * @param images 4-D image tensor; each image is 28x28 grayscale with pixel
 *     values assumed in [0, 1] (scaled by 255 here — confirm with loader).
 * @param start Index of the first image to convert.
 * @param size Number of images to convert.
 * @returns One PNG buffer per image, in order.
 */
async function toPng(images: tf.Tensor4D, start: number, size: number): Promise<Buffer[]> {
  const pixels = images.slice([start], [size]).flatten().arraySync();
  const ret: Buffer[] = [];
  for (let i = 0; i < size; i++) {
    // Expand each grayscale value into an opaque RGBA quadruple.
    const raw: number[] = [];
    for (const v of pixels.slice(i * 28 * 28, (i + 1) * 28 * 28)) {
      raw.push(v * 255, v * 255, v * 255, 255);
    }
    // Jimp's constructor takes an error-first callback; wrap it in a promise.
    const img = await promisify((cb: (err: Error | null, v: Jimp) => void) => {
      new Jimp({ data: Buffer.from(raw), width: 28, height: 28 }, cb);
    })();
    ret.push(await img.getBufferAsync(Jimp.MIME_PNG));
  }
  return ret;
}
// A small VGG-style convnet for 28x28x1 MNIST images: two conv blocks
// (each: 2x conv + max-pool) followed by a dense classifier head.
const model = tf.sequential();
// Block 1: 32 filters per conv layer; 28x28 -> 26x26 -> 24x24 -> 12x12
// (see the model.summary() output below).
model.add(tf.layers.conv2d({
inputShape: [28, 28, 1],
filters: 32,
kernelSize: 3,
activation: 'relu',
}));
model.add(tf.layers.conv2d({
filters: 32,
kernelSize: 3,
activation: 'relu',
}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
// Block 2: 64 filters per conv layer; 12x12 -> 10x10 -> 8x8 -> 4x4.
model.add(tf.layers.conv2d({
filters: 64,
kernelSize: 3,
activation: 'relu',
}));
model.add(tf.layers.conv2d({
filters: 64,
kernelSize: 3,
activation: 'relu',
}));
model.add(tf.layers.maxPooling2d({poolSize: [2, 2]}));
// Classifier head: flatten, dropout for regularization, a hidden dense
// layer, and a 10-way softmax over the digit classes.
model.add(tf.layers.flatten());
model.add(tf.layers.dropout({rate: 0.25}));
model.add(tf.layers.dense({units: 512, activation: 'relu'}));
model.add(tf.layers.dropout({rate: 0.5}));
model.add(tf.layers.dense({units: 10, activation: 'softmax'}));
// Compile with RMSProp and categorical cross-entropy (labels are one-hot),
// tracking accuracy during training.
const optimizer = 'rmsprop';
model.compile({
optimizer: optimizer,
loss: 'categoricalCrossentropy',
metrics: ['accuracy'],
});
async function train(epochs, batchSize, modelSavePath) {
const {images: trainImages, labels: trainLabels} = mnist.getTrainData();
model.summary();
let epochBeginTime;
let millisPerStep;
const validationSplit = 0.15;
const numTrainExamplesPerEpoch =
trainImages.shape[0] * (1 - validationSplit);
const numTrainBatchesPerEpoch =
Math.ceil(numTrainExamplesPerEpoch / batchSize);
await model.fit(trainImages, trainLabels, {
epochs,
batchSize,
validationSplit
});
const {images: testImages, labels: testLabels} = mnist.getTestData();
const evalOutput = model.evaluate(testImages, testLabels);
console.log(
`\nEvaluation result:\n` +
` Loss = ${evalOutput[0].dataSync()[0].toFixed(3)}; `+
`Accuracy = ${evalOutput[1].dataSync()[0].toFixed(3)}`);
if (modelSavePath != null) {
await model.save(`file://${modelSavePath}`);
console.log(`Saved model to path: ${modelSavePath}`);
}
}
// Hack to suppress the progress bar
process.stderr.isTTY = false;
// Training hyperparameters and the output directory for the saved model.
const epochs = 20;
const batchSize = 128;
const modelSavePath = 'mnist'
await train(epochs, batchSize, modelSavePath);
_________________________________________________________________ Layer (type) Output shape Param # ================================================================= conv2d_Conv2D1 (Conv2D) [null,26,26,32] 320 _________________________________________________________________ conv2d_Conv2D2 (Conv2D) [null,24,24,32] 9248 _________________________________________________________________ max_pooling2d_MaxPooling2D1 [null,12,12,32] 0 _________________________________________________________________ conv2d_Conv2D3 (Conv2D) [null,10,10,64] 18496 _________________________________________________________________ conv2d_Conv2D4 (Conv2D) [null,8,8,64] 36928 _________________________________________________________________ max_pooling2d_MaxPooling2D2 [null,4,4,64] 0 _________________________________________________________________ flatten_Flatten1 (Flatten) [null,1024] 0 _________________________________________________________________ dropout_Dropout1 (Dropout) [null,1024] 0 _________________________________________________________________ dense_Dense1 (Dense) [null,512] 524800 _________________________________________________________________ dropout_Dropout2 (Dropout) [null,512] 0 _________________________________________________________________ dense_Dense2 (Dense) [null,10] 5130 ================================================================= Total params: 594922 Trainable params: 594922 Non-trainable params: 0 _________________________________________________________________ Epoch 1 / 20
40717ms 798us/step - acc=0.920 loss=0.245 val_acc=0.979 val_loss=0.0735 Epoch 2 / 20
39984ms 784us/step - acc=0.980 loss=0.0674 val_acc=0.990 val_loss=0.0360 Epoch 3 / 20
40109ms 786us/step - acc=0.985 loss=0.0491 val_acc=0.990 val_loss=0.0371 Epoch 4 / 20
42172ms 827us/step - acc=0.988 loss=0.0379 val_acc=0.992 val_loss=0.0294 Epoch 5 / 20
42451ms 832us/step - acc=0.990 loss=0.0320 val_acc=0.992 val_loss=0.0285 Epoch 6 / 20
42674ms 837us/step - acc=0.991 loss=0.0283 val_acc=0.987 val_loss=0.0481 Epoch 7 / 20
42504ms 833us/step - acc=0.993 loss=0.0234 val_acc=0.992 val_loss=0.0263 Epoch 8 / 20
43120ms 845us/step - acc=0.993 loss=0.0218 val_acc=0.993 val_loss=0.0263 Epoch 9 / 20
42818ms 840us/step - acc=0.994 loss=0.0191 val_acc=0.993 val_loss=0.0274 Epoch 10 / 20
43198ms 847us/step - acc=0.994 loss=0.0177 val_acc=0.994 val_loss=0.0213 Epoch 11 / 20
43481ms 853us/step - acc=0.995 loss=0.0150 val_acc=0.994 val_loss=0.0253 Epoch 12 / 20
43164ms 846us/step - acc=0.995 loss=0.0154 val_acc=0.994 val_loss=0.0263 Epoch 13 / 20
42980ms 843us/step - acc=0.995 loss=0.0135 val_acc=0.994 val_loss=0.0251 Epoch 14 / 20
43289ms 849us/step - acc=0.996 loss=0.0126 val_acc=0.994 val_loss=0.0255 Epoch 15 / 20
43104ms 845us/step - acc=0.996 loss=0.0113 val_acc=0.992 val_loss=0.0333 Epoch 16 / 20
43385ms 851us/step - acc=0.997 loss=0.0102 val_acc=0.993 val_loss=0.0320 Epoch 17 / 20
43223ms 848us/step - acc=0.996 loss=0.0106 val_acc=0.993 val_loss=0.0308 Epoch 18 / 20
43164ms 846us/step - acc=0.997 loss=9.44e-3 val_acc=0.994 val_loss=0.0329 Epoch 19 / 20
43324ms 849us/step - acc=0.997 loss=8.45e-3 val_acc=0.994 val_loss=0.0319 Epoch 20 / 20
42775ms 839us/step - acc=0.997 loss=8.62e-3 val_acc=0.994 val_loss=0.0270
Evaluation result:
Loss = 0.021; Accuracy = 0.994
Saved model to path: mnist
undefined
// Predicted digit (0-9) for every test image: argmax over the model's
// 10-way softmax outputs.
const predicted =
tf.argMax(model.predict(mnist.getTestData().images) as tf.Tensor, 1).arraySync() as number[];
// Render 32 test images starting at index 100, each with its predicted
// digit centered underneath.
{
  const start = 100;
  const size = 32;
  const pngs = await toPng(mnist.getTestData().images, start, size);
  const html: string[] = [];
  html.push('<div style="display:flex;flex-wrap:wrap;max-width:480px">');
  pngs.forEach((png, i) => {
    html.push('<div>');
    html.push(`<img src="data:image/png;base64,${png.toString('base64')}">`);
    html.push(`<div style="text-align:center">${predicted[i + start]}</div>`);
    html.push('</div>');
  });
  html.push('</div>');
  display.html(html.join('\n'));
}
// Show examples where the model failed to predict the correct labels.
{
  const start = 100;
  const size = 2000;
  const pngs = await toPng(mnist.getTestData().images, start, size);
  // True labels, decoded from the one-hot test label tensor.
  const labels = tf.argMax(mnist.getTestData().labels, 1).arraySync() as number[];
  const html: string[] = [];
  html.push('<div style="display:flex;flex-wrap:wrap;max-width:480px">');
  for (const [i, png] of pngs.entries()) {
    const pred = predicted[start + i];
    const label = labels[start + i];
    // Only render the images the model got wrong.
    if (pred !== label) {
      html.push('<div style="border:solid black 1px;text-align:center;margin:4px">');
      html.push(`<img style="display:inline-block" src="data:image/png;base64,${png.toString('base64')}">`);
      html.push(`<div style="text-align:center">Label: ${label}, Prediction: ${pred}</div>`);
      html.push('</div>');
    }
  }
  html.push('</div>');
  display.html(html.join('\n'));
}