%matplotlib inline
import os
os.environ['THEANO_FLAGS']='device=gpu0'
import matplotlib
import numpy as np
np.random.seed(123)
import matplotlib.pyplot as plt
import lasagne
import theano
import theano.tensor as T
conv = lasagne.layers.Conv2DLayer
pool = lasagne.layers.MaxPool2DLayer
NUM_EPOCHS = 500
BATCH_SIZE = 256
LEARNING_RATE = 0.001
DIM = 60
NUM_CLASSES = 10
mnist_cluttered = "mnist_cluttered_60x60_6distortions.npz"
Using gpu device 0: Graphics Device
We use Lasagne to classify cluttered MNIST digits with the spatial transformer network introduced in [1]. A spatial transformer network applies a learned affine transformation to its input.
We test the spatial transformer network on the cluttered MNIST dataset.
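To make "applies a learned affine transformation" concrete, here is a rough numpy sketch of the sampling the transformer layer performs on a single-channel image: the six parameters in theta form a 2x3 affine matrix that maps each normalized output coordinate to a source coordinate in the input, where the output pixel is sampled. This is only an illustration under our own assumptions; the helper name affine_sample is ours, and Lasagne's TransformerLayer uses bilinear interpolation rather than the nearest-neighbour sampling used here for brevity.

def affine_sample(img, theta, out_h, out_w):
    # theta holds the six affine parameters as a 2x3 matrix [[a, b, t_x], [c, d, t_y]]
    theta = theta.reshape(2, 3)
    in_h, in_w = img.shape
    out = np.zeros((out_h, out_w), dtype=img.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # normalized target coordinates in [-1, 1]
            y_t = -1 + 2.0 * i / (out_h - 1)
            x_t = -1 + 2.0 * j / (out_w - 1)
            # affine mapping from target coordinates to source coordinates
            x_s = theta[0, 0] * x_t + theta[0, 1] * y_t + theta[0, 2]
            y_s = theta[1, 0] * x_t + theta[1, 1] * y_t + theta[1, 2]
            # back to pixel indices; sample with nearest neighbour for simplicity
            xi = int(round((x_s + 1) * (in_w - 1) / 2.0))
            yi = int(round((y_s + 1) * (in_h - 1) / 2.0))
            if 0 <= xi < in_w and 0 <= yi < in_h:
                out[i, j] = img[yi, xi]
    return out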
Download the data (41 MB) with:
!wget -N https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz
--2015-08-19 14:52:08--  https://s3.amazonaws.com/lasagne/recipes/datasets/mnist_cluttered_60x60_6distortions.npz
Resolving s3.amazonaws.com... 54.231.48.99
Connecting to s3.amazonaws.com|54.231.48.99|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 43046126 (41M) [application/octet-stream]
Server file no newer than local file 'mnist_cluttered_60x60_6distortions.npz' -- not retrieving.
def load_data():
    data = np.load(mnist_cluttered)
    X_train, y_train = data['x_train'], np.argmax(data['y_train'], axis=-1)
    X_valid, y_valid = data['x_valid'], np.argmax(data['y_valid'], axis=-1)
    X_test, y_test = data['x_test'], np.argmax(data['y_test'], axis=-1)

    # reshape for convolutions
    X_train = X_train.reshape((X_train.shape[0], 1, DIM, DIM))
    X_valid = X_valid.reshape((X_valid.shape[0], 1, DIM, DIM))
    X_test = X_test.reshape((X_test.shape[0], 1, DIM, DIM))

    print "Train samples:", X_train.shape
    print "Validation samples:", X_valid.shape
    print "Test samples:", X_test.shape

    return dict(
        X_train=lasagne.utils.floatX(X_train),
        y_train=y_train.astype('int32'),
        X_valid=lasagne.utils.floatX(X_valid),
        y_valid=y_valid.astype('int32'),
        X_test=lasagne.utils.floatX(X_test),
        y_test=y_test.astype('int32'),
        num_examples_train=X_train.shape[0],
        num_examples_valid=X_valid.shape[0],
        num_examples_test=X_test.shape[0],
        input_height=X_train.shape[2],
        input_width=X_train.shape[3],
        output_dim=10,)
data = load_data()
Train samples: (50000, 1, 60, 60)
Validation samples: (10000, 1, 60, 60)
Test samples: (10000, 1, 60, 60)
plt.figure(figsize=(7,7))
plt.imshow(data['X_train'][101].reshape(DIM, DIM), cmap='gray', interpolation='none')
plt.title('Cluttered MNIST', fontsize=20)
plt.axis('off')
plt.show()
We use a model in which the localization network is a two-layer convolutional network operating directly on the image input. The output of the localization network is a six-dimensional vector specifying the parameters of the affine transformation.
The localization network feeds into the transformer layer, which applies the transformation to the image input. In our setup the transformer layer downsamples the input by a factor of 3.
Finally, a two-layer convolutional network followed by two fully connected layers computes the output class probabilities.
The model:

Input -> localization_network -> TransformerLayer -> output_network -> predictions
   |                                |
   >--------------------------------^
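For intuition about the six parameters: the identity transform theta = [1, 0, 0, 0, 1, 0] reproduces the input (up to the fixed downsampling), while scale parameters smaller than one zoom in on a sub-region. The snippet below is purely illustrative, reusing the hypothetical affine_sample helper sketched above with hand-picked thetas; in the model the thetas are predicted per image by the localization network.

# Hand-picked thetas for illustration only (the model learns them via the localization network)
img = data['X_train'][101, 0]                              # one 60x60 cluttered digit
identity = np.array([1, 0, 0, 0, 1, 0], dtype='float32')   # plain 3x downsample
zoom = np.array([0.5, 0, 0, 0, 0.5, 0], dtype='float32')   # zoom in on the central region
downsampled = affine_sample(img, identity, DIM // 3, DIM // 3)
zoomed = affine_sample(img, zoom, DIM // 3, DIM // 3)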
def build_model(input_width, input_height, output_dim,
                batch_size=BATCH_SIZE):
    ini = lasagne.init.HeUniform()
    l_in = lasagne.layers.InputLayer(shape=(None, 1, input_width, input_height),)

    # Localization network
    # Initialize the affine parameters to the identity transform: the output layer
    # below has zero weights, so at the start of training it outputs exactly this
    # bias and the transformer performs a plain downsampling of the input.
    b = np.zeros((2, 3), dtype=theano.config.floatX)
    b[0, 0] = 1
    b[1, 1] = 1
    b = b.flatten()
    loc_l1 = pool(l_in, pool_size=(2, 2))
    loc_l2 = conv(
        loc_l1, num_filters=20, filter_size=(5, 5), W=ini)
    loc_l3 = pool(loc_l2, pool_size=(2, 2))
    loc_l4 = conv(loc_l3, num_filters=20, filter_size=(5, 5), W=ini)
    loc_l5 = lasagne.layers.DenseLayer(
        loc_l4, num_units=50, W=lasagne.init.HeUniform('relu'))
    loc_out = lasagne.layers.DenseLayer(
        loc_l5, num_units=6, b=b, W=lasagne.init.Constant(0.0),
        nonlinearity=lasagne.nonlinearities.identity)

    # Transformer network
    l_trans1 = lasagne.layers.TransformerLayer(l_in, loc_out, downsample_factor=3.0)
    print "Transformer network output shape: ", l_trans1.output_shape

    # Classification network
    class_l1 = conv(
        l_trans1,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l2 = pool(class_l1, pool_size=(2, 2))
    class_l3 = conv(
        class_l2,
        num_filters=32,
        filter_size=(3, 3),
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    class_l4 = pool(class_l3, pool_size=(2, 2))
    class_l5 = lasagne.layers.DenseLayer(
        class_l4,
        num_units=256,
        nonlinearity=lasagne.nonlinearities.rectify,
        W=ini,
    )
    l_out = lasagne.layers.DenseLayer(
        class_l5,
        num_units=output_dim,
        nonlinearity=lasagne.nonlinearities.softmax,
        W=ini,
    )
    return l_out, l_trans1
model, l_transform = build_model(DIM, DIM, NUM_CLASSES)
model_params = lasagne.layers.get_all_params(model, trainable=True)
Transformer network output shape: (None, 1, 20, 20)
X = T.tensor4()
y = T.ivector()
# training output
output_train = lasagne.layers.get_output(model, X, deterministic=False)
# evaluation output. Also includes output of transform for plotting
output_eval, transform_eval = lasagne.layers.get_output([model, l_transform], X, deterministic=True)
sh_lr = theano.shared(lasagne.utils.floatX(LEARNING_RATE))
cost = T.mean(T.nnet.categorical_crossentropy(output_train, y))
updates = lasagne.updates.adam(cost, model_params, learning_rate=sh_lr)
train = theano.function([X, y], [cost, output_train], updates=updates)
eval = theano.function([X], [output_eval, transform_eval])
def train_epoch(X, y):
    num_samples = X.shape[0]
    num_batches = int(np.ceil(num_samples / float(BATCH_SIZE)))
    costs = []
    correct = 0
    for i in range(num_batches):
        idx = range(i*BATCH_SIZE, np.minimum((i+1)*BATCH_SIZE, num_samples))
        X_batch = X[idx]
        y_batch = y[idx]
        cost_batch, output_train = train(X_batch, y_batch)
        costs += [cost_batch]
        preds = np.argmax(output_train, axis=-1)
        correct += np.sum(y_batch == preds)
    return np.mean(costs), correct / float(num_samples)

def eval_epoch(X, y):
    output_eval, transform_eval = eval(X)
    preds = np.argmax(output_eval, axis=-1)
    acc = np.mean(preds == y)
    return acc, transform_eval
valid_accs, train_accs, test_accs = [], [], []
try:
    for n in range(NUM_EPOCHS):
        train_cost, train_acc = train_epoch(data['X_train'], data['y_train'])
        valid_acc, valid_transform = eval_epoch(data['X_valid'], data['y_valid'])
        test_acc, test_transform = eval_epoch(data['X_test'], data['y_test'])
        valid_accs += [valid_acc]
        test_accs += [test_acc]
        train_accs += [train_acc]

        # decay the learning rate by 30% every 20 epochs
        if (n+1) % 20 == 0:
            new_lr = sh_lr.get_value() * 0.7
            print "New LR:", new_lr
            sh_lr.set_value(lasagne.utils.floatX(new_lr))

        print "Epoch {0}: Train cost {1}, Train acc {2}, val acc {3}, test acc {4}".format(
                n, train_cost, train_acc, valid_acc, test_acc)
except KeyboardInterrupt:
    pass
Epoch 0: Train cost 1.72300577164, Train acc 0.38824, val acc 0.6114, test acc 0.6087
Epoch 1: Train cost 0.867130100727, Train acc 0.71758, val acc 0.7745, test acc 0.7759
Epoch 2: Train cost 0.618825733662, Train acc 0.79848, val acc 0.8199, test acc 0.827
Epoch 3: Train cost 0.475057393312, Train acc 0.8489, val acc 0.8602, test acc 0.8613
Epoch 4: Train cost 0.369837403297, Train acc 0.88208, val acc 0.8697, test acc 0.8723
Epoch 5: Train cost 0.336995840073, Train acc 0.89126, val acc 0.8957, test acc 0.8974
Epoch 6: Train cost 0.288021206856, Train acc 0.90742, val acc 0.9005, test acc 0.8993
Epoch 7: Train cost 0.260697960854, Train acc 0.915, val acc 0.9081, test acc 0.9091
Epoch 8: Train cost 0.235620766878, Train acc 0.92484, val acc 0.917, test acc 0.9214
Epoch 9: Train cost 0.232491567731, Train acc 0.9245, val acc 0.9205, test acc 0.921
Epoch 10: Train cost 0.214803680778, Train acc 0.92916, val acc 0.9249, test acc 0.926
Epoch 11: Train cost 0.191879570484, Train acc 0.93728, val acc 0.9306, test acc 0.9317
Epoch 12: Train cost 0.187945634127, Train acc 0.93854, val acc 0.9365, test acc 0.937
Epoch 13: Train cost 0.177504748106, Train acc 0.94238, val acc 0.9329, test acc 0.933
Epoch 14: Train cost 0.161393344402, Train acc 0.9479, val acc 0.9246, test acc 0.9269
Epoch 15: Train cost 0.158181488514, Train acc 0.9482, val acc 0.9353, test acc 0.9382
Epoch 16: Train cost 0.162177875638, Train acc 0.94768, val acc 0.9399, test acc 0.9385
Epoch 17: Train cost 0.150974154472, Train acc 0.95074, val acc 0.9417, test acc 0.944
Epoch 18: Train cost 0.13878442347, Train acc 0.9546, val acc 0.9514, test acc 0.9481
New LR: 0.000700000033248
Epoch 19: Train cost 0.139381811023, Train acc 0.95302, val acc 0.9465, test acc 0.9477
Epoch 20: Train cost 0.115818083286, Train acc 0.96186, val acc 0.9498, test acc 0.9515
Epoch 21: Train cost 0.10844618082, Train acc 0.96364, val acc 0.9537, test acc 0.9544
Epoch 22: Train cost 0.104168988764, Train acc 0.9651, val acc 0.95, test acc 0.9522
Epoch 23: Train cost 0.100386917591, Train acc 0.96664, val acc 0.9523, test acc 0.9533
Epoch 24: Train cost 0.101429723203, Train acc 0.9666, val acc 0.9516, test acc 0.9557
Epoch 25: Train cost 0.0968987718225, Train acc 0.96804, val acc 0.9523, test acc 0.9556
Epoch 26: Train cost 0.0905688554049, Train acc 0.97016, val acc 0.955, test acc 0.9533
Epoch 27: Train cost 0.0892679914832, Train acc 0.97024, val acc 0.9574, test acc 0.9537
Epoch 28: Train cost 0.0790596753359, Train acc 0.9733, val acc 0.956, test acc 0.9577
Epoch 29: Train cost 0.0846520811319, Train acc 0.97228, val acc 0.9586, test acc 0.9575
Epoch 30: Train cost 0.0861563980579, Train acc 0.9711, val acc 0.9553, test acc 0.9579
Epoch 31: Train cost 0.084160938859, Train acc 0.9713, val acc 0.9574, test acc 0.9565
Epoch 32: Train cost 0.0740946382284, Train acc 0.97538, val acc 0.9583, test acc 0.9568
Epoch 33: Train cost 0.0750161111355, Train acc 0.97476, val acc 0.9522, test acc 0.9558
Epoch 34: Train cost 0.0719307512045, Train acc 0.97592, val acc 0.9534, test acc 0.9601
Epoch 35: Train cost 0.0688360854983, Train acc 0.97742, val acc 0.9568, test acc 0.9578
Epoch 36: Train cost 0.0659850463271, Train acc 0.97732, val acc 0.9586, test acc 0.9602
Epoch 37: Train cost 0.0669036284089, Train acc 0.97736, val acc 0.9606, test acc 0.9581
Epoch 38: Train cost 0.0615548193455, Train acc 0.9792, val acc 0.9584, test acc 0.9538
New LR: 0.000490000023274
Epoch 39: Train cost 0.0617390647531, Train acc 0.9795, val acc 0.9585, test acc 0.9574
Epoch 40: Train cost 0.0535897053778, Train acc 0.9818, val acc 0.9563, test acc 0.9582
Epoch 41: Train cost 0.0471548065543, Train acc 0.98434, val acc 0.9622, test acc 0.9613
Epoch 42: Train cost 0.0408403426409, Train acc 0.98648, val acc 0.9635, test acc 0.9624
Epoch 43: Train cost 0.0405819378793, Train acc 0.98642, val acc 0.9636, test acc 0.9619
Epoch 44: Train cost 0.0374028384686, Train acc 0.98754, val acc 0.9606, test acc 0.9614
Epoch 45: Train cost 0.0365789830685, Train acc 0.98828, val acc 0.9591, test acc 0.9574
Epoch 46: Train cost 0.0347327440977, Train acc 0.98848, val acc 0.962, test acc 0.9613
plt.figure(figsize=(9,9))
plt.plot(1-np.array(train_accs), label='Training Error')
plt.plot(1-np.array(valid_accs), label='Validation Error')
plt.legend(fontsize=20)
plt.xlabel('Epoch', fontsize=20)
plt.ylabel('Error', fontsize=20)
plt.show()
plt.figure(figsize=(7,14))
for i in range(3):
    plt.subplot(321+i*2)
    plt.imshow(data['X_test'][i].reshape(DIM, DIM), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Original 60x60', fontsize=20)
    plt.axis('off')
    plt.subplot(322+i*2)
    plt.imshow(test_transform[i].reshape(DIM//3, DIM//3), cmap='gray', interpolation='none')
    if i == 0:
        plt.title('Transformed 20x20', fontsize=20)
    plt.axis('off')
plt.tight_layout()
[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015).