This is a tutorial based on the fast.ai Practical Deep Learning course and their Cats and Dogs Redux.
This Jupyter Notebook uses the lein-jupyter plugin to execute Clojure code in a project setting. The first time you run it, run lein jupyter install-kernel to install the kernel.
After that, run lein jupyter notebook from the project directory to open the notebook.
The Deep Learning library in Clojure that we are using is called Cortex. This example borrows heavily from the Resnet-retrain example in the repo.
The goal is to take the dataset from the Kaggle Cats and Dogs Redux competition and train a model on it. We will then run the model on the Kaggle test data and submit the results to see how we compare.
You will need to do a bit of setup first. Download the train.zip and test.zip files from the Kaggle Data page.
Make a data directory in the project directory, put the train.zip file there, and unzip it.
You should now have a data/train directory containing lots of pictures with names like cat.156.jpg and dog.1.jpg.
Next, run the get-resnet50.sh script in the root of the project. It downloads a pretrained model that classifies images into the 1000 categories of the ImageNet competition. The model and its weights have been translated into a Cortex network and saved in the nippy format.
(require '[clojure.java.shell :as shell])
(-> (shell/sh "./get-resnet50.sh") :out println)
(-> (shell/sh "ls" "models") :out println)
resnet50.nippy
(-> (shell/sh "ls" "-d" "data/train") :out println)
data/train
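Before going further, it can help to confirm that the results of the setup steps are in place. The following is a minimal sketch that assumes the paths used above; missing-paths is a hypothetical helper name and is not part of this project.

```clojure
(require '[clojure.java.io :as io])

;; Hypothetical helper (not part of the project): returns the entries of
;; `paths` that do not exist relative to the current working directory.
(defn missing-paths
  [paths]
  (vec (remove #(.exists (io/file %)) paths)))

;; An empty vector means the model file and the training images were found.
(println (missing-paths ["models/resnet50.nippy" "data/train"]))
```

If either path is printed, repeat the matching setup step before continuing.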
Next, let's require some of the functions in the core namespace of this project. In fact, let's go ahead and create a proper namespace for this worksheet and require the libs we need.
(ns cats-dogs-cortex-redux.notebook
  (:require [cats-dogs-cortex-redux.core :as cats-dogs]
            [clojure.java.shell :as shell]))
Now we want to get the cats and dogs data that is currently in the data/train directory into a directory layout that looks like this:
-data
  -cats-dogs-training
    -cat
      1110.png
      ...
    -dog
      12416.png
      ...
  -cats-dogs-testing
    -cat
      11.png
      ...
    -dog
      12.png
      ...
The main thing we want to do is take all the images in the data/train directory, shuffle them, and split them into 85% for training and 15% for validation testing. We also want a directory structure in which only the cat pictures are in the cat directory and only the dog pictures are in the dog directory. Finally, every picture should be resized to 224x224 to match the input that the ResNet model expects.
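The shuffle-and-split step described above can be sketched in plain Clojure. This is an illustration under stated assumptions, not the project's implementation: label-of and split-files are hypothetical names, the file names are generated stand-ins for the contents of data/train, and the resizing step is left out.

```clojure
(require '[clojure.string :as string])

;; Hypothetical helper: the label is the part of the file name before the
;; first dot, e.g. "cat.156.jpg" -> "cat".
(defn label-of
  [file-name]
  (first (string/split file-name #"\.")))

;; Shuffle the file names, keep `training-fraction` of them for training and
;; the rest for testing, then group each part by label.
(defn split-files
  [file-names training-fraction]
  (let [shuffled (shuffle file-names)
        n-train  (Math/round (double (* training-fraction (count shuffled))))
        [training testing] (split-at n-train shuffled)]
    {:training (group-by label-of training)
     :testing  (group-by label-of testing)}))

;; Stand-in file names (assumption); the real ones come from data/train.
(def example-files
  (for [label ["cat" "dog"]
        i     (range 12500)]
    (str label "." i ".jpg")))

(def example-split (split-files example-files 0.85))

;; Number of training and testing files after the split.
(println (count (mapcat val (:training example-split)))
         (count (mapcat val (:testing example-split))))
```

With 25000 file names and a fraction of 0.85, this yields 21250 training files and 3750 testing files. Grouping by label mirrors the cat and dog subdirectories in the layout above.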
There is a build-image-data function in the cats-dogs namespace that will do this for us.
(cats-dogs/build-image-data)
Building the image data with a test-training split of 0.85 training files 21250 testing files 3750
(-> (shell/sh "ls" "data/cats-dogs-training/") :out println)
cat dog
(-> (shell/sh "ls" "data/cats-dogs-testing/") :out println)
cat dog
The main function our code calls to create the model is (load-network "models/resnet50.nippy" :fc1000 layers-to-add). Let's take a closer look at it.
(def layers-to-add
  [(layers/linear 2 :id :fc2)
   (layers/softmax :id :labels)])

(defn load-network
  [network-file chop-layer top-layers]
  (let [network (util/read-nippy-file network-file)
        ;; remove last layer(s)
        chopped-net (network/dissoc-layers-from-network network chop-layer)
        ;; set layers to non-trainable
        nodes (get-in chopped-net [:compute-graph :nodes]) ;;=> {:linear-1 {<params>}
        new-node-params (mapv (fn [params] (assoc params :non-trainable? true)) (vals nodes))
        frozen-nodes (zipmap (keys nodes) new-node-params)
        frozen-net (assoc-in chopped-net [:compute-graph :nodes] frozen-nodes)
        ;; add top layers
        modified-net (network/assoc-layers-to-network frozen-net (flatten top-layers))]
    modified-net))
It takes the pretrained resnet50 model, chops off the final layers starting at :fc1000, and marks all the remaining layers as non-trainable, so that training leaves them frozen and spends no time retraining them. It then tacks on the layers that we want to add: a linear layer of size 2 for our two classes (cat and dog) and a softmax activation layer that returns the probabilities that the image is a cat or a dog, with the values summing to 1.
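The freezing step is ordinary map manipulation, which can be seen on a toy example. The sketch below uses a made-up two-node map in place of a real Cortex network (real node maps carry many more keys), and freeze-nodes is a hypothetical name for the part of load-network that sets :non-trainable? on every node.

```clojure
;; Toy stand-in for a Cortex network (assumption: only the keys used by the
;; freezing step are modelled here).
(def toy-net
  {:compute-graph
   {:nodes {:conv1  {:type :convolutional}
            :fc1000 {:type :linear :output-size 1000}}}})

;; Mark every node as non-trainable, leaving the rest of the network untouched.
(defn freeze-nodes
  [net]
  (update-in net [:compute-graph :nodes]
             (fn [nodes]
               (into {}
                     (map (fn [[id params]]
                            [id (assoc params :non-trainable? true)]))
                     nodes))))

(def frozen-toy-net (freeze-nodes toy-net))

;; Each node map now carries :non-trainable? true next to its original keys.
(println (get-in frozen-toy-net [:compute-graph :nodes :conv1]))
```

This produces the same kind of node map as the mapv and zipmap combination in load-network; update-in simply keeps the keys and values together while transforming them.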
Let's load up the model and take a look.
(require '[cortex.util :as util])
(def plain-res50 (util/read-nippy-file "models/resnet50.nippy"))
#'cats-dogs-cortex-redux.notebook/plain-res50
The Cortex network is just data, so we can take a look at the layers:
(keys (:compute-graph plain-res50))
(:nodes :edges :buffers :streams)
(clojure.pprint/pprint (sort (keys (get-in plain-res50 [:compute-graph :nodes]))))
(:activation_1 :activation_10 :activation_10-split :activation_11 :activation_12 :activation_13 :activation_13-split :activation_14 :activation_15 :activation_16 :activation_16-split :activation_17 :activation_18 :activation_19 :activation_19-split :activation_2 :activation_20 :activation_21 :activation_22 :activation_22-split :activation_23 :activation_24 :activation_25 :activation_25-split :activation_26 :activation_27 :activation_28 :activation_28-split :activation_29 :activation_3 :activation_30 :activation_31 :activation_31-split :activation_32 :activation_33 :activation_34 :activation_34-split :activation_35 :activation_36 :activation_37 :activation_37-split :activation_38 :activation_39 :activation_4 :activation_4-split :activation_40 :activation_40-split :activation_41 :activation_42 :activation_43 :activation_43-split :activation_44 :activation_45 :activation_46 :activation_46-split :activation_47 :activation_48 :activation_49 :activation_5 :activation_6 :activation_7 :activation_7-split :activation_8 :activation_9 :add_1 :add_10 :add_11 :add_12 :add_13 :add_14 :add_15 :add_16 :add_2 :add_3 :add_4 :add_5 :add_6 :add_7 :add_8 :add_9 :avg_pool :bn2a_branch1 :bn2a_branch2a :bn2a_branch2b :bn2a_branch2c :bn2b_branch2a :bn2b_branch2b :bn2b_branch2c :bn2c_branch2a :bn2c_branch2b :bn2c_branch2c :bn3a_branch1 :bn3a_branch2a :bn3a_branch2b :bn3a_branch2c :bn3b_branch2a :bn3b_branch2b :bn3b_branch2c :bn3c_branch2a :bn3c_branch2b :bn3c_branch2c :bn3d_branch2a :bn3d_branch2b :bn3d_branch2c :bn4a_branch1 :bn4a_branch2a :bn4a_branch2b :bn4a_branch2c :bn4b_branch2a :bn4b_branch2b :bn4b_branch2c :bn4c_branch2a :bn4c_branch2b :bn4c_branch2c :bn4d_branch2a :bn4d_branch2b :bn4d_branch2c :bn4e_branch2a :bn4e_branch2b :bn4e_branch2c :bn4f_branch2a :bn4f_branch2b :bn4f_branch2c :bn5a_branch1 :bn5a_branch2a :bn5a_branch2b :bn5a_branch2c :bn5b_branch2a :bn5b_branch2b :bn5b_branch2c :bn5c_branch2a :bn5c_branch2b :bn5c_branch2c :bn_conv1 :conv1 :data :fc1000 :fc1000-activation 
:max_pooling2d_1 :max_pooling2d_1-split :res2a_branch1 :res2a_branch2a :res2a_branch2b :res2a_branch2c :res2b_branch2a :res2b_branch2b :res2b_branch2c :res2c_branch2a :res2c_branch2b :res2c_branch2c :res3a_branch1 :res3a_branch2a :res3a_branch2b :res3a_branch2c :res3b_branch2a :res3b_branch2b :res3b_branch2c :res3c_branch2a :res3c_branch2b :res3c_branch2c :res3d_branch2a :res3d_branch2b :res3d_branch2c :res4a_branch1 :res4a_branch2a :res4a_branch2b :res4a_branch2c :res4b_branch2a :res4b_branch2b :res4b_branch2c :res4c_branch2a :res4c_branch2b :res4c_branch2c :res4d_branch2a :res4d_branch2b :res4d_branch2c :res4e_branch2a :res4e_branch2b :res4e_branch2c :res4f_branch2a :res4f_branch2b :res4f_branch2c :res5a_branch1 :res5a_branch2a :res5a_branch2b :res5a_branch2c :res5b_branch2a :res5b_branch2b :res5b_branch2c :res5c_branch2a :res5c_branch2b :res5c_branch2c :softmax-loss-1)
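Because the node map is plain Clojure data, core functions such as frequencies work on it directly. The sketch below counts nodes per layer type on a small made-up map whose ids mimic the listing above; toy-nodes and node-type-counts are hypothetical names, and it assumes each node map has a :type key.

```clojure
;; Toy stand-in for the :nodes map of a network (assumption: every node map
;; has a :type key).
(def toy-nodes
  {:conv1        {:type :convolutional}
   :bn_conv1     {:type :batch-normalization}
   :activation_1 {:type :relu}
   :fc1000       {:type :linear}})

;; Count how many nodes there are of each layer type.
(defn node-type-counts
  [nodes]
  (frequencies (map :type (vals nodes))))

(println (node-type-counts toy-nodes))
```

Under the same assumption, passing (get-in plain-res50 [:compute-graph :nodes]) to node-type-counts would give the counts for the real network.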
There is a function to help print out the network:
(require '[cortex.nn.network :as network]
'[cortex.nn.traverse :as traverse])
(network/print-layer-summary plain-res50 (traverse/training-traversal plain-res50))
| type | input | output | :bias | :means | :scale | :variances | :weights | |----------------------+---------------------+---------------------+--------+--------+--------+------------+-------------| | :convolutional | 3x224x224 - 150528 | 64x112x112 - 802816 | [64] | | | | [64 147] | | :batch-normalization | 64x112x112 - 802816 | 64x112x112 - 802816 | [64] | [64] | [64] | [64] | | | :relu | 64x112x112 - 802816 | 64x112x112 - 802816 | | | | | | | :max-pooling | 64x112x112 - 802816 | 64x55x55 - 193600 | | | | | | | :split | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 64] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 576] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :convolutional | 64x55x55 - 193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :join | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :relu | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :split | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :convolutional | 256x55x55 - 774400 | 64x55x55 - 193600 | [64] | | | | [64 256] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 576] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | 
[64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :join | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :relu | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :split | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :convolutional | 256x55x55 - 774400 | 64x55x55 - 193600 | [64] | | | | [64 256] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 576] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :join | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :relu | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :split | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :convolutional | 256x55x55 - 774400 | 128x28x28 - 100352 | [128] | | | | [128 256] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | 
| | :convolutional | 256x55x55 - 774400 | 512x28x28 - 401408 | [512] | | | | [512 256] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :split | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 128x28x28 - 100352 | [128] | | | | [128 512] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :split | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 128x28x28 - 100352 | [128] | | | | [128 512] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | 
| | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :split | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 128x28x28 - 100352 | [128] | | | | [128 512] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :split | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 256x14x14 - 50176 | [256] | | | | [256 512] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :convolutional | 512x28x28 - 401408 | 1024x14x14 - 200704 | [1024] | | | | [1024 512] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 
200704 | | | | | | | :split | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :split | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :split | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | 
[256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :split | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :split | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 
50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :split | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 512x7x7 - 25088 | [512] | | | | [512 1024] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | | | | [512 4608] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 2048x7x7 - 100352 | [2048] | | | | [2048 512] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | [2048] | [2048] | | | :convolutional | 1024x14x14 - 200704 | 2048x7x7 - 100352 | [2048] | | | | [2048 1024] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | [2048] | [2048] | | | :join | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :relu | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :split | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :convolutional | 2048x7x7 - 100352 | 512x7x7 - 25088 | [512] | | | | [512 2048] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | | | | [512 4608] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | 
:convolutional | 512x7x7 - 25088 | 2048x7x7 - 100352 | [2048] | | | | [2048 512] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | [2048] | [2048] | | | :join | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :relu | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :split | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :convolutional | 2048x7x7 - 100352 | 512x7x7 - 25088 | [512] | | | | [512 2048] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | | | | [512 4608] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 2048x7x7 - 100352 | [2048] | | | | [2048 512] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | [2048] | [2048] | | | :join | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :relu | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :max-pooling | 2048x7x7 - 100352 | 2048x1x1 - 2048 | | | | | | | :linear | 2048x1x1 - 2048 | 1x1x1000 - 1000 | [1000] | | | | [1000 2048] | | :softmax | 1x1x1000 - 1000 | 1x1x1000 - 1000 | | | | | | Parameter count: 25636712
Let's see what it looks like after we modify the network:
(def cats-dogs-network (cats-dogs/load-network "models/resnet50.nippy" :fc1000 cats-dogs/layers-to-add))
#'cats-dogs-cortex-redux.notebook/cats-dogs-network
(network/print-layer-summary cats-dogs-network (traverse/training-traversal cats-dogs-network))
| type | input | output | :bias | :means | :scale | :variances | :weights | |----------------------+---------------------+---------------------+--------+--------+--------+------------+-------------| | :convolutional | 3x224x224 - 150528 | 64x112x112 - 802816 | [64] | | | | [64 147] | | :batch-normalization | 64x112x112 - 802816 | 64x112x112 - 802816 | [64] | [64] | [64] | [64] | | | :relu | 64x112x112 - 802816 | 64x112x112 - 802816 | | | | | | | :max-pooling | 64x112x112 - 802816 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 64] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 576] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :convolutional | 64x55x55 - 193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :join | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :relu | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :convolutional | 256x55x55 - 774400 | 64x55x55 - 193600 | [64] | | | | [64 256] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 576] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 
193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :join | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :relu | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :convolutional | 256x55x55 - 774400 | 64x55x55 - 193600 | [64] | | | | [64 256] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | | | | [64 576] | | :batch-normalization | 64x55x55 - 193600 | 64x55x55 - 193600 | [64] | [64] | [64] | [64] | | | :relu | 64x55x55 - 193600 | 64x55x55 - 193600 | | | | | | | :convolutional | 64x55x55 - 193600 | 256x55x55 - 774400 | [256] | | | | [256 64] | | :batch-normalization | 256x55x55 - 774400 | 256x55x55 - 774400 | [256] | [256] | [256] | [256] | | | :join | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :relu | 256x55x55 - 774400 | 256x55x55 - 774400 | | | | | | | :convolutional | 256x55x55 - 774400 | 128x28x28 - 100352 | [128] | | | | [128 256] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :convolutional | 256x55x55 - 774400 | 512x28x28 - 401408 | [512] | | | | [512 256] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | 
| | | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 128x28x28 - 100352 | [128] | | | | [128 512] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 128x28x28 - 100352 | [128] | | | | [128 512] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 128x28x28 - 100352 | [128] | | | | [128 512] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 128x28x28 - 
100352 | [128] | | | | [128 1152] | | :batch-normalization | 128x28x28 - 100352 | 128x28x28 - 100352 | [128] | [128] | [128] | [128] | | | :relu | 128x28x28 - 100352 | 128x28x28 - 100352 | | | | | | | :convolutional | 128x28x28 - 100352 | 512x28x28 - 401408 | [512] | | | | [512 128] | | :batch-normalization | 512x28x28 - 401408 | 512x28x28 - 401408 | [512] | [512] | [512] | [512] | | | :join | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :relu | 512x28x28 - 401408 | 512x28x28 - 401408 | | | | | | | :convolutional | 512x28x28 - 401408 | 256x14x14 - 50176 | [256] | | | | [256 512] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :convolutional | 512x28x28 - 401408 | 1024x14x14 - 200704 | [1024] | | | | [1024 512] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 
50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 
1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 256x14x14 - 50176 | [256] | | | | [256 1024] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | | | | [256 2304] | | :batch-normalization | 256x14x14 - 50176 | 256x14x14 - 50176 | [256] | [256] | [256] | [256] | | | :relu | 256x14x14 - 50176 | 256x14x14 - 50176 | | | | | | | :convolutional | 256x14x14 - 50176 | 1024x14x14 - 200704 | [1024] | | | | [1024 256] | | :batch-normalization | 1024x14x14 - 200704 | 1024x14x14 - 200704 | [1024] | [1024] | [1024] | [1024] | | | :join | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :relu | 1024x14x14 - 200704 | 1024x14x14 - 200704 | | | | | | | :convolutional | 1024x14x14 - 200704 | 512x7x7 - 25088 | [512] | | | | [512 1024] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | | | | [512 4608] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] 
| [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 2048x7x7 - 100352 | [2048] | | | | [2048 512] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | [2048] | [2048] | | | :convolutional | 1024x14x14 - 200704 | 2048x7x7 - 100352 | [2048] | | | | [2048 1024] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | [2048] | [2048] | | | :join | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :relu | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :convolutional | 2048x7x7 - 100352 | 512x7x7 - 25088 | [512] | | | | [512 2048] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | | | | [512 4608] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 2048x7x7 - 100352 | [2048] | | | | [2048 512] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | [2048] | [2048] | | | :join | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :relu | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :convolutional | 2048x7x7 - 100352 | 512x7x7 - 25088 | [512] | | | | [512 2048] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | | | | [512 4608] | | :batch-normalization | 512x7x7 - 25088 | 512x7x7 - 25088 | [512] | [512] | [512] | [512] | | | :relu | 512x7x7 - 25088 | 512x7x7 - 25088 | | | | | | | :convolutional | 512x7x7 - 25088 | 2048x7x7 - 100352 | [2048] | | | | [2048 512] | | :batch-normalization | 2048x7x7 - 100352 | 2048x7x7 - 100352 | [2048] | [2048] | 
[2048] | [2048] | | | :join | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :relu | 2048x7x7 - 100352 | 2048x7x7 - 100352 | | | | | | | :max-pooling | 2048x7x7 - 100352 | 2048x1x1 - 2048 | | | | | | | :linear | 2048x1x1 - 2048 | 1x1x2 - 2 | [2] | | | | [2 2048] | | :softmax | 1x1x2 - 2 | 1x1x2 - 2 | | | | | |
Parameter count: 23591810
Notice that the last linear and softmax layers have been replaced. Also note that the weights are frozen everywhere except in the last two layers, which we can confirm by checking each node's :non-trainable? flag:
(clojure.pprint/pprint (->> (map (fn [[name value]] {:name name :non-trainable? (:non-trainable? value)})
                                 (get-in cats-dogs-network [:compute-graph :nodes]))
                            (sort-by :non-trainable?)))
({:name :labels, :non-trainable? nil} {:name :softmax-loss-1, :non-trainable? nil} {:name :fc2, :non-trainable? nil} {:name :res2a_branch2b, :non-trainable? true} {:name :bn4f_branch2b, :non-trainable? true} {:name :bn4e_branch2b, :non-trainable? true} {:name :activation_2, :non-trainable? true} {:name :res4f_branch2a, :non-trainable? true} {:name :activation_29, :non-trainable? true} {:name :add_16, :non-trainable? true} {:name :activation_48, :non-trainable? true} {:name :res4e_branch2b, :non-trainable? true} {:name :activation_6, :non-trainable? true} {:name :add_9, :non-trainable? true} {:name :res2a_branch2a, :non-trainable? true} {:name :res4d_branch2a, :non-trainable? true} {:name :res2c_branch2b, :non-trainable? true} {:name :bn5c_branch2a, :non-trainable? true} {:name :add_8, :non-trainable? true} {:name :activation_31-split, :non-trainable? true} {:name :res5c_branch2c, :non-trainable? true} {:name :add_7, :non-trainable? true} {:name :add_1, :non-trainable? true} {:name :bn2b_branch2a, :non-trainable? true} {:name :bn4c_branch2a, :non-trainable? true} {:name :res5a_branch2c, :non-trainable? true} {:name :activation_18, :non-trainable? true} {:name :bn5a_branch2c, :non-trainable? true} {:name :res3d_branch2a, :non-trainable? true} {:name :add_11, :non-trainable? true} {:name :res4a_branch2a, :non-trainable? true} {:name :add_6, :non-trainable? true} {:name :res3b_branch2b, :non-trainable? true} {:name :bn3d_branch2a, :non-trainable? true} {:name :res2b_branch2b, :non-trainable? true} {:name :res4e_branch2a, :non-trainable? true} {:name :activation_7, :non-trainable? true} {:name :bn_conv1, :non-trainable? true} {:name :bn4b_branch2a, :non-trainable? true} {:name :activation_37, :non-trainable? true} {:name :bn4f_branch2c, :non-trainable? true} {:name :res3a_branch2a, :non-trainable? true} {:name :activation_35, :non-trainable? true} {:name :bn2b_branch2b, :non-trainable? true} {:name :activation_15, :non-trainable? 
true} {:name :activation_24, :non-trainable? true} {:name :bn5b_branch2b, :non-trainable? true} {:name :bn3b_branch2c, :non-trainable? true} {:name :activation_47, :non-trainable? true} {:name :activation_40-split, :non-trainable? true} {:name :res4c_branch2c, :non-trainable? true} {:name :activation_8, :non-trainable? true} {:name :activation_13, :non-trainable? true} {:name :activation_26, :non-trainable? true} {:name :activation_45, :non-trainable? true} {:name :activation_19-split, :non-trainable? true} {:name :bn5b_branch2a, :non-trainable? true} {:name :res4c_branch2b, :non-trainable? true} {:name :activation_12, :non-trainable? true} {:name :activation_11, :non-trainable? true} {:name :activation_43, :non-trainable? true} {:name :res3b_branch2a, :non-trainable? true} {:name :activation_21, :non-trainable? true} {:name :res3c_branch2b, :non-trainable? true} {:name :res2a_branch2c, :non-trainable? true} {:name :bn4a_branch2a, :non-trainable? true} {:name :bn5b_branch2c, :non-trainable? true} {:name :bn4d_branch2a, :non-trainable? true} {:name :bn4a_branch2c, :non-trainable? true} {:name :bn4d_branch2b, :non-trainable? true} {:name :activation_19, :non-trainable? true} {:name :activation_44, :non-trainable? true} {:name :bn4c_branch2b, :non-trainable? true} {:name :activation_27, :non-trainable? true} {:name :activation_31, :non-trainable? true} {:name :add_4, :non-trainable? true} {:name :res5c_branch2a, :non-trainable? true} {:name :activation_34-split, :non-trainable? true} {:name :res4c_branch2a, :non-trainable? true} {:name :res4a_branch2c, :non-trainable? true} {:name :bn4a_branch2b, :non-trainable? true} {:name :activation_4, :non-trainable? true} {:name :activation_34, :non-trainable? true} {:name :bn2a_branch2b, :non-trainable? true} {:name :res4b_branch2a, :non-trainable? true} {:name :res3c_branch2a, :non-trainable? true} {:name :res3d_branch2c, :non-trainable? true} {:name :res5b_branch2b, :non-trainable? 
true} {:name :res4f_branch2b, :non-trainable? true} {:name :activation_39, :non-trainable? true} {:name :activation_10-split, :non-trainable? true} {:name :res3c_branch2c, :non-trainable? true} {:name :res3d_branch2b, :non-trainable? true} {:name :bn4b_branch2b, :non-trainable? true} {:name :res2b_branch2c, :non-trainable? true} {:name :bn4e_branch2a, :non-trainable? true} {:name :conv1, :non-trainable? true} {:name :activation_7-split, :non-trainable? true} {:name :activation_49, :non-trainable? true} {:name :bn5a_branch2a, :non-trainable? true} {:name :bn4e_branch2c, :non-trainable? true} {:name :bn3a_branch2c, :non-trainable? true} {:name :bn4b_branch2c, :non-trainable? true} {:name :activation_33, :non-trainable? true} {:name :bn5c_branch2c, :non-trainable? true} {:name :add_12, :non-trainable? true} {:name :activation_10, :non-trainable? true} {:name :avg_pool, :non-trainable? true} {:name :bn4c_branch2c, :non-trainable? true} {:name :activation_1, :non-trainable? true} {:name :activation_22-split, :non-trainable? true} {:name :activation_20, :non-trainable? true} {:name :activation_25, :non-trainable? true} {:name :res5c_branch2b, :non-trainable? true} {:name :activation_40, :non-trainable? true} {:name :activation_16-split, :non-trainable? true} {:name :add_3, :non-trainable? true} {:name :add_2, :non-trainable? true} {:name :bn4a_branch1, :non-trainable? true} {:name :bn3b_branch2b, :non-trainable? true} {:name :activation_28-split, :non-trainable? true} {:name :activation_46-split, :non-trainable? true} {:name :res4d_branch2c, :non-trainable? true} {:name :res4a_branch1, :non-trainable? true} {:name :bn3c_branch2a, :non-trainable? true} {:name :res4b_branch2b, :non-trainable? true} {:name :activation_23, :non-trainable? true} {:name :bn2c_branch2a, :non-trainable? true} {:name :activation_3, :non-trainable? true} {:name :res2b_branch2a, :non-trainable? true} {:name :activation_16, :non-trainable? true} {:name :res5a_branch1, :non-trainable? 
true} {:name :res3a_branch2c, :non-trainable? true} {:name :bn5a_branch1, :non-trainable? true} {:name :activation_17, :non-trainable? true} {:name :activation_46, :non-trainable? true} {:name :activation_41, :non-trainable? true} {:name :bn3a_branch1, :non-trainable? true} {:name :bn2a_branch2a, :non-trainable? true} {:name :res5b_branch2c, :non-trainable? true} {:name :bn5c_branch2b, :non-trainable? true} {:name :activation_43-split, :non-trainable? true} {:name :bn3a_branch2a, :non-trainable? true} {:name :res2c_branch2c, :non-trainable? true} {:name :activation_28, :non-trainable? true} {:name :bn3c_branch2c, :non-trainable? true} {:name :activation_38, :non-trainable? true} {:name :activation_42, :non-trainable? true} {:name :res4b_branch2c, :non-trainable? true} {:name :res5b_branch2a, :non-trainable? true} {:name :bn2a_branch2c, :non-trainable? true} {:name :bn2b_branch2c, :non-trainable? true} {:name :bn2a_branch1, :non-trainable? true} {:name :activation_37-split, :non-trainable? true} {:name :res4d_branch2b, :non-trainable? true} {:name :res3a_branch2b, :non-trainable? true} {:name :bn5a_branch2b, :non-trainable? true} {:name :activation_13-split, :non-trainable? true} {:name :res2c_branch2a, :non-trainable? true} {:name :res3b_branch2c, :non-trainable? true} {:name :activation_32, :non-trainable? true} {:name :bn3c_branch2b, :non-trainable? true} {:name :max_pooling2d_1-split, :non-trainable? true} {:name :res3a_branch1, :non-trainable? true} {:name :bn4f_branch2a, :non-trainable? true} {:name :bn2c_branch2c, :non-trainable? true} {:name :activation_14, :non-trainable? true} {:name :res5a_branch2b, :non-trainable? true} {:name :add_15, :non-trainable? true} {:name :res4f_branch2c, :non-trainable? true} {:name :res4a_branch2b, :non-trainable? true} {:name :add_14, :non-trainable? true} {:name :activation_25-split, :non-trainable? true} {:name :activation_4-split, :non-trainable? true} {:name :activation_22, :non-trainable? 
true} {:name :bn3d_branch2b, :non-trainable? true} {:name :res5a_branch2a, :non-trainable? true} {:name :res4e_branch2c, :non-trainable? true} {:name :res2a_branch1, :non-trainable? true} {:name :add_10, :non-trainable? true} {:name :bn4d_branch2c, :non-trainable? true} {:name :add_13, :non-trainable? true} {:name :bn3b_branch2a, :non-trainable? true} {:name :activation_5, :non-trainable? true} {:name :bn3a_branch2b, :non-trainable? true} {:name :bn3d_branch2c, :non-trainable? true} {:name :activation_36, :non-trainable? true} {:name :max_pooling2d_1, :non-trainable? true} {:name :activation_9, :non-trainable? true} {:name :activation_30, :non-trainable? true} {:name :data, :non-trainable? true} {:name :add_5, :non-trainable? true} {:name :bn2c_branch2b, :non-trainable? true})
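As a rough check on how little of the network is actually being trained, the size of the replaced linear layer can be worked out from its shape in the summary (2048 inputs, 2 outputs) and compared with the reported total of 23591810. This is only a sketch: the helper name linear-param-count is made up for illustration and is not part of Cortex, and it assumes the layer has one weight per input/output pair plus one bias per output.

```clojure
;; Hypothetical helper (not a Cortex function): a fully connected layer has
;; one weight per input/output pair plus one bias per output.
(defn linear-param-count
  [n-inputs n-outputs]
  (+ (* n-inputs n-outputs) n-outputs))

;; Total taken from the "Parameter count" line of the network summary.
(def total-params 23591810)

;; The new :linear layer maps 2048 inputs to 2 outputs
;; (weights [2 2048], bias [2] in the summary).
(def trainable-params (linear-param-count 2048 2))

(println "Trainable parameters:" trainable-params)
(println "Frozen parameters:   " (- total-params trainable-params))
(println "Trainable fraction:  " (double (/ trainable-params total-params)))
```

Under those assumptions, only 4098 of the 23591810 parameters are updated during fine-tuning; the softmax layer has no parameters of its own.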
Now we're at a point where we can actually do the training. There is a function called train that takes a batch size; the default is 32. If you are running on your own computer and run into memory problems, you can try decreasing the batch size or running the core code as an uberjar to do the training. Another option is to run on an AWS P2 compute instance.
To be able to run it on my old Mac, I need to use the uberjar.
If you want to go the uberjar route:
lein uberjar
java -jar target/cats-dogs-cortex-redux.jar
Using the GPU on my Mac, it takes approximately 6 minutes to train 1 epoch. Note: 1 epoch of fine-tuning is all we need.
Loss for epoch 1: (current) 0.05875186542016347 (best) null
Saving network to trained-network.nippy
The key point is that it saved the fine-tuned network to trained-network.nippy.
Note that we are only going to do 1 epoch of fine-tuning.
(ns cats-dogs-cortex-redux.notebook
  (:require [cats-dogs-cortex-redux.core :as cats-dogs]
            [clojure.java.shell :as shell]))
Now we can test things out a bit. There is a label-one function that grabs a random image and classifies it.
Note that a window will pop up with the dog or cat picture in it.
(cats-dogs/label-one)
{:answer "dog", :guess {:prob 0.9978577494621277, :class "dog"}}
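The result map above can also be checked programmatically. The following is a small sketch that assumes every call to label-one returns a map with the same :answer and :guess keys as the output shown; the names correct-guess? and accuracy are invented for this example and are not part of the project.

```clojure
;; A result in the same shape as the label-one output shown above.
(def result
  {:answer "dog", :guess {:prob 0.9978577494621277, :class "dog"}})

;; Hypothetical helper: true when the guessed class matches the real answer.
(defn correct-guess?
  [{:keys [answer guess]}]
  (= answer (:class guess)))

;; Hypothetical helper: fraction of correct guesses in a collection of results.
(defn accuracy
  [results]
  (/ (count (filter correct-guess? results))
     (double (count results))))

(println "Correct?" (correct-guess? result))
(println "Accuracy over one sample:" (accuracy [result]))
```

Collecting several label-one results into a vector and passing them to accuracy would give a quick spot check, bearing in mind that each call pops up an image window.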
You will need to do a bit more setup before running the model on the Kaggle test data. First, you need to get the Kaggle test images for classification. There are 12500 of these in the test.zip file from the site. Under the data directory, create a new directory called kaggle-test. Now unzip the contents of test.zip inside that folder. The full directory with all the test images should now be:
data/kaggle-test/test
This step takes a long time to run, and you might have to tweak the batch size again depending on your memory. There are 12500 predictions to be made. The main logic for this is in a function called kaggle-results, which is called as (kaggle-results batch-size).
It writes the results to the kaggle-results.csv file as it goes along. If you want to check progress, you can run wc -l kaggle-results.csv
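The same progress check can be done from another Clojure REPL instead of the shell. This is a minimal sketch, assuming the file holds one line per prediction (if the file starts with a header line, the count is off by one); line-count and progress are names made up for this example.

```clojure
(require '[clojure.java.io :as io])

;; Number of test images, as stated above.
(def total-predictions 12500)

;; Hypothetical helper: count the lines in a file, like `wc -l`.
(defn line-count
  [path]
  (with-open [rdr (io/reader path)]
    (count (line-seq rdr))))

;; Hypothetical helper: fraction of predictions written so far, capped at 1.0.
;; Assumes one line per prediction.
(defn progress
  [path]
  (min 1.0 (/ (line-count path) (double total-predictions))))
```

Calling (progress "kaggle-results.csv") while the predictions are running returns a number between 0 and 1; it throws a FileNotFoundException if the file has not been created yet.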
(ns cats-dogs-cortex-redux.notebook
  (:require [cats-dogs-cortex-redux.core :as cats-dogs]
            [clojure.java.shell :as shell]))
(cats-dogs/kaggle-results 100)
.............................................................................................................................
clojure.lang.LazySeq@f331f79f
Done! It took me about 28 minutes locally.
Now you can take the kaggle-results.csv file and upload it to the competition and check your results!
My score was 0.10357 (the competition is scored by log loss, so lower is better).
Whoo!
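For a sense of what that number measures: the Kaggle leaderboard for this competition scores submissions by log loss over the predicted probability that each image is a dog (label 1 for dog, 0 for cat). Below is a minimal sketch of that metric. The function name log-loss and the clipping constant 1e-15 are choices made for this example, not something taken from Cortex or from this project.

```clojure
;; Binary log loss: -1/n * sum(y*log(p) + (1-y)*log(1-p)).
;; labels are 0 or 1, probs are predicted probabilities of class 1.
;; Probabilities are clipped away from 0 and 1 so that log never sees 0;
;; the eps value is an assumption for this sketch.
(defn log-loss
  [labels probs]
  (let [eps  1e-15
        clip (fn [p] (min (- 1.0 eps) (max eps p)))
        term (fn [y p]
               (let [p (clip p)]
                 (+ (* y (Math/log p))
                    (* (- 1 y) (Math/log (- 1.0 p))))))]
    (- (/ (reduce + (map term labels probs))
          (count labels)))))

;; Always answering 0.5 scores log(2), roughly 0.693, whatever the labels are.
(println (log-loss [1 0] [0.5 0.5]))

;; A confident, correct guess like the label-one example contributes very little.
(println (log-loss [1] [0.9978577494621277]))
```

A confident wrong answer is penalized heavily under this metric, so a low score requires predictions that are both accurate and well calibrated.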