以下にDeep Learning with Javascriptの3章の例題をGoogle Colabで実行してみます。 例題は、以下のGithubにありました。
ここでは3章の中から、ROC曲線の描画を通じて2値クラスの解析方法をみてみましょう。
以下のセルを実行して、本ノートブックの再ロード(F5キーを押下)してください。
一度カーネル環境が作成された後(Google Colabのセッションが有効な間)は実行する必要はありません。
!npm install -g npm@latest
!npm cache verify
!npm install -g --unsafe-perm ijavascript
!ijsinstall --install=global
!jupyter-kernelspec list
evalmachine.<anonymous>:1 !npm install -g npm@latest ^^^^^^^ SyntaxError: Unexpected identifier at createScript (vm.js:80:10) at Object.runInThisContext (vm.js:139:10) at run ([eval]:1054:15) at onRunRequest ([eval]:888:18) at onMessage ([eval]:848:13) at emitTwo (events.js:126:13) at process.emit (events.js:214:7) at emit (internal/child_process.js:772:12) at _combinedTickCallback (internal/process/next_tick.js:141:11) at process._tickCallback (internal/process/next_tick.js:180:9)
以降のセルからjavascriptが使えるようになります。
注意事項 constとlet文は使わず、varを使ってください
セル内でコマンドを実行するsh関数を定義します。
var { spawn } = require('child_process')
var sh = (cmd) => {
$$.async()
var sp = spawn(cmd, { cwd: process.cwd(), stdio: 'pipe', shell: true, encoding: 'utf-8' })
sp.stdout.on('data', data => console.log(data.toString()))
sp.stderr.on('data', data => console.error(data.toString()))
sp.on('close', () => $$.done())
}
var run_async = async (pf) => {
$$.async()
await pf()
$$.done()
}
sh('npm init -y')
Wrote to /content/package.json: { "name": "content", "version": "1.0.0", "main": "index.js", "scripts": { "test": "echo \"Error: no test specified\" && exit 1" }, "keywords": [], "author": "", "license": "ISC", "dependencies": { "@tensorflow/tfjs-node-gpu": "^1.7.4", "papaparse": "^5.2.0", "plotly-notebook-js": "^0.1.2", "xmlhttprequest": "^1.8.0" }, "devDependencies": {}, "description": "" }
sh('npm install @tensorflow/tfjs-node-gpu')
sh('npm install plotly-notebook-js')
sh('npm install papaparse')
sh('npm install xmlhttprequest')
npm WARN content@1.0.0 No description npm WARN content@1.0.0 No repository field.
+ xmlhttprequest@1.8.0 updated 1 package and audited 169 packages in 2.275s
npm WARN content@1.0.0 No description npm WARN content@1.0.0 No repository field.
1 package is looking for funding run `npm fund` for details found 0 vulnerabilities + papaparse@5.2.0 updated 1 package and audited 169 packages in 2.272s 1 package is looking for funding run `npm fund` for details found 0 vulnerabilities
npm WARN content@1.0.0 No description npm WARN content@1.0.0 No repository field.
以下のセルの実行でエラーになった場合には、ブラウザーのF5(再読み込み)を実行してもう一度試してみまてください。
最初に必要なライブラリのインスタンスを生成します。
var tf = require('@tensorflow/tfjs-node-gpu')
var Papa = require('papaparse')
var XMLHttpRequest = require("xmlhttprequest").XMLHttpRequest;
var Plot = require('plotly-notebook-js')
var NotebookPlot = Plot.createPlot().constructor
NotebookPlot.prototype._toHtml = NotebookPlot.prototype.render
node-pre-gyp info This Node instance does not support builds for N-API version 4 node-pre-gyp info This Node instance does not support builds for N-API version 5 node-pre-gyp info This Node instance does not support builds for N-API version 4 node-pre-gyp info This Node instance does not support builds for N-API version 5
[Function]
(node:1951) Warning: N-API is an experimental feature and could change at any time.
3章のデータは、csvファイルが以下のURLで公開されています。loadCsv関数を使ってダウンロードし、各カラムの値をセットします。
var BASE_URL =
'https://gist.githubusercontent.com/ManrajGrover/6589d3fd3eb9a0719d2a83128741dfc1/raw/d0a86602a87bfe147c240e87e6a9641786cafc19/';
function loadCsv(filename, data) {
const url = `${BASE_URL}${filename}.csv`;
Papa.parse(url, {
download: true,
header: true,
complete: (results) => {
console.log(`got ${filename}`);
//console.log(results['data']);
data[filename] = results['data'].map((row) => {
return Object.keys(row).sort().map(key => parseFloat(row[key]));
});
}
})
}
ダウンロードしたデータは、dataset変数にセットします。
var dataset = {}
loadCsv('train-data', dataset)
loadCsv('train-target', dataset)
loadCsv('test-data', dataset)
loadCsv('test-target', dataset)
got test-target got train-target got train-data
訓練用データtrain-dataにどのような値が入っているかみてみましょう。
console.log(dataset['train-data'][0])
[ -1, -1, -1, -1, -1, 1, 1, 1, -1, -1, -1, 1, 1, 1, 1, -1, 1, 1, -1, 0, 1, 1, -1, 1, -1, -1, -1, 1, -1, -1 ]
train-target, test-targetを1個の配列に納めて、モデルの定数をセットします。
dataset['train-target'] = Float32Array.from([].concat.apply([], dataset['train-target']))
dataset['test-target'] = Float32Array.from([].concat.apply([], dataset['test-target']))
var NUM_FEATURES = 30
var NUM_CLASSES = 2
var EPOCHS = 100
var BATCH_SIZE = 350
var trainSize = dataset['train-data'].length
var testSize = dataset['test-data'].length
Float32Array [ 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, ... 5427 more items ]
正、負の2値を分類する場合の予測結果の評価尺度は、正解率(accuracy)だけではありません。
しましまサイトのF値からaccuracy, precision, recallの定義を引用します。
真結果 | |||
---|---|---|---|
正 | 負 | ||
予測結果 | 正 | TP | FP |
負 | FN | TN |
precisionとrecallはなかなか両立しないため、ROC曲線より上の面積が小さくなるポイントを採用します。
ROC曲線は、縦軸に$recall=\frac{TP}{TP + FN}$、横軸に$\frac{FP}{FP + TN}$を取った曲線です。
ROC曲線のデータをセットするcalcROC関数を原著のdrawROCを参考に作りました。
面白いことに、Tensorflow.jsには、recallは用意されているのですが、falsePositiveRateは無いためサンプルコードには、以下のように定義されていました。
// falsePositives, trueNegatives, falsePositiveRateは、Tensorflow.jsが
// tf.metrics.falsePositiveRateをサポートするまでの暫定版
function falsePositives(yTrue, yPred) {
return tf.tidy(() => {
const one = tf.scalar(1);
const zero = tf.scalar(0);
return tf.logicalAnd(yTrue.equal(zero), yPred.equal(one))
.sum()
.cast('float32');
})
}
function trueNegatives(yTrue, yPred) {
return tf.tidy(() => {
const zero = tf.scalar(0);
return tf.logicalAnd(yTrue.equal(zero), yPred.equal(zero))
.sum()
.cast('float32');
})
}
function falsePositiveRate(yTrue, yPred) {
return tf.tidy(() => {
const fp = falsePositives(yTrue, yPred);
const tn = trueNegatives(yTrue, yPred);
return fp.div(fp.add(tn));
})
}
function calcROC(ROCLines, targets, probs, epoch) {
return tf.tidy(() => {
const thresholds = [
0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55,
0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.92, 0.94, 0.96, 0.98, 1.0
]
const tprs = [] // True positive rates.
const fprs = [] // False positive rates.
let area = 0
for (let i = 0; i < thresholds.length; ++i) {
const threshold = thresholds[i]
const threshPredictions = tf.tidy(() => {
const condition = probs.greater(tf.scalar(threshold));
return tf.where(condition, tf.onesLike(probs), tf.zerosLike(probs));
}).as1D()
const fpr = falsePositiveRate(targets, threshPredictions).dataSync()[0]
const tpr = tf.metrics.recall(targets, threshPredictions).dataSync()[0]
fprs.push(fpr)
tprs.push(tpr)
// Accumulate to area for AUC calculation.
if (i > 0) {
area += (tprs[i] + tprs[i - 1]) * (fprs[i - 1] - fprs[i]) / 2
}
}
ROCLines.push({
x: fprs,
y: tprs,
name: epoch,
mode: 'lines'
})
return area;
});
}
入力のデータ数が30個、中間層の数を100個、最終結果は1個とし、活性化関数にsigmoidを使用します。
optimizerにはadam、損失関数は、binaryCrossentropy、尺度にmetricsを指定してcompileします。
var model = tf.sequential();
model.add(tf.layers.dense(
{inputShape: [NUM_FEATURES], units: 100, activation: 'sigmoid'}));
model.add(tf.layers.dense({units: 100, activation: 'sigmoid'}));
model.add(tf.layers.dense({units: 1, activation: 'sigmoid'}));
model.compile(
{optimizer: 'adam', loss: 'binaryCrossentropy', metrics: ['accuracy']});
var trainAcc = []
var trainLoss = []
var validAcc = []
var validLoss = []
var ROCLines = []
var trainData = tf.tensor2d(dataset['train-data'], [trainSize, NUM_FEATURES])
var trainTarget = tf.tensor1d(dataset['train-target'])
var testData = tf.tensor2d(dataset['test-data'], [testSize, NUM_FEATURES])
var testTarget = tf.tensor1d(dataset['test-target'])
fitメソッドで、最適化を実行します。ROCは、callbacksで0, 2, 4と25毎に出力します。
fitメソッドは最適化の結果を返すPromiseが戻されるので、thenで結果を受け取り、グルーバル変数trainAcc、trainLoss、validAcc、validLossにセットします。
model.fit(
trainData,
trainTarget,
{
batchsize: BATCH_SIZE,
epochs: EPOCHS,
validationSplit: 0.2,
callbacks: {
onEpochBegin: async (epoch) => {
if ((epoch + 1) % 25 === 0 || epoch === 0 || epoch === 2 || epoch === 4) {
const probs = model.predict(testData);
auc = calcROC(ROCLines, testTarget, probs, epoch+1);
}
}
}
}
).then(info => {
trainAcc = info.history.acc
trainLoss = info.history.loss
validAcc = info.history.val_acc
validLoss = info.history.val_loss
});
Epoch 1 / 100
5672ms 1283us/step - acc=0.789 loss=0.506 val_acc=0.890 val_loss=0.330 Epoch 2 / 100
3564ms 806us/step - acc=0.917 loss=0.228 val_acc=0.909 val_loss=0.241 Epoch 3 / 100
3796ms 858us/step - acc=0.923 loss=0.196 val_acc=0.914 val_loss=0.224 Epoch 4 / 100
3412ms 772us/step - acc=0.923 loss=0.190 val_acc=0.912 val_loss=0.220 Epoch 5 / 100
3855ms 872us/step - acc=0.929 loss=0.185 val_acc=0.916 val_loss=0.214 Epoch 6 / 100
3598ms 814us/step - acc=0.927 loss=0.184 val_acc=0.915 val_loss=0.212 Epoch 7 / 100
3658ms 827us/step - acc=0.928 loss=0.183 val_acc=0.918 val_loss=0.211 Epoch 8 / 100
3470ms 785us/step - acc=0.929 loss=0.183 val_acc=0.916 val_loss=0.206 Epoch 9 / 100
3636ms 822us/step - acc=0.927 loss=0.184 val_acc=0.907 val_loss=0.224 Epoch 10 / 100
3450ms 780us/step - acc=0.929 loss=0.184 val_acc=0.918 val_loss=0.209 Epoch 11 / 100
3808ms 861us/step - acc=0.928 loss=0.184 val_acc=0.916 val_loss=0.219 Epoch 12 / 100
最適化によって正確率(acc)をプロットします。
var x = Array.from(Array(EPOCHS).keys()).map(v => v + 1)
var trainAccLine = {
x: x,
y: trainAcc,
name: 'Train acc',
mode: 'lines'
}
var validAccLine = {
x: x,
y: validAcc,
name: 'Validation acc',
mode: 'lines'
}
Plot.createPlot([trainAccLine, validAccLine], {
width: 600,
title: 'Model Accuracy',
xaxis: {
title: 'EPOCH'
},
yaxis: {
title: 'Accuracy'
}
})
次に損失関数の値をプロットします。60EPOCHあたりで検証用データの損失関数の値が下止まりしているように見えます。
var trainLossLine = {
x: x,
y: trainLoss,
name: 'Train loss',
mode: 'lines'
}
var validLossLine = {
x: x,
y: validLoss,
name: 'Validation loss',
mode: 'lines'
}
Plot.createPlot([trainLossLine, validLossLine], {
width: 600,
title: 'Model Accuracy',
xaxis: {
title: 'EPOCH'
},
yaxis: {
title: 'Accuracy'
}
})
最後にROC曲線をプロットします。75と100EPOCHがほぼ同じになるので、損失関数の結果を考慮すると75EPOCHのモデルがよいと思われます。
Plot.createPlot(ROCLines, {
width: 500,
height: 500,
title: 'ROC curve',
xaxis: {
title: 'FPR'
},
yaxis: {
title: 'TPR'
}
})