%matplotlib inline
import math,sys,os,numpy as np
from numpy.linalg import norm
from PIL import Image
from matplotlib import pyplot as plt, rcParams, rc
from scipy.ndimage import imread
from skimage.measure import block_reduce
import pickle as pickle
from scipy.ndimage.filters import correlate, convolve
# Notebook display defaults.
# Render matplotlib animations as HTML5 video in the cell output.
rc('animation', html='html5')
# Default figure size (width, height) in inches.
rcParams['figure.figsize'] = 3, 6
# IPython magic: display floats with 4 decimal places in cell output.
%precision 4
# Same 4-decimal precision for printed NumPy arrays; wrap lines at 100 characters.
np.set_printoptions(precision=4, linewidth=100)
def plots(ims, interp=False, titles=None, figsize=(12, 24)):
    """Show several images side by side in a single row.

    All panels share one colour scale (the global min/max over every image),
    so intensities are directly comparable between panels.

    Parameters
    ----------
    ims : sequence of 2-D arrays (anything ``np.array`` can stack)
        Images to display, one per subplot.
    interp : bool
        If False (default), disable interpolation so individual pixels stay
        visible; if True, use matplotlib's default interpolation.
    titles : sequence, optional
        One title per image, drawn above its panel.
    figsize : tuple of float
        Figure size in inches; defaults to the previously hard-coded (12, 24).
    """
    ims = np.array(ims)
    if len(ims) == 0:
        # Nothing to draw; ims.min() would raise on an empty array.
        return
    vmin, vmax = ims.min(), ims.max()
    interpolation = None if interp else 'none'
    fig = plt.figure(figsize=figsize)
    for i, im in enumerate(ims):
        ax = fig.add_subplot(1, len(ims), i + 1)
        if titles is not None:
            ax.set_title(titles[i], fontsize=18)
        # Draw on this panel's axes explicitly, with the shared colour scale.
        ax.imshow(im, interpolation=interpolation, vmin=vmin, vmax=vmax)
def plot(im, interp=False):
    """Show a single image in its own small (3 x 6 inch) figure.

    Parameters
    ----------
    im : 2-D array-like
        Image to display (nested lists also work, e.g. a small kernel).
    interp : bool
        If False (default), disable interpolation so individual pixels stay
        visible; if True, use matplotlib's default interpolation.
    """
    plt.figure(figsize=(3, 6), frameon=True)
    plt.imshow(im, interpolation=None if interp else 'none')


# Module-level setup (not part of plot(): closing a figure inside plot()
# would discard the image it just drew).
# Make grayscale the default colormap for every image drawn afterwards.
plt.gray()
# plt.gray() may leave an empty "current" figure behind; close it so the
# notebook does not render a blank plot.
plt.close()
# `fetch_mldata` was removed in scikit-learn 0.22 and mldata.org is offline,
# so fall back to the OpenML copy of MNIST when it cannot be used.
# NOTE(review): the OpenML copy may use a different sample order and stores
# targets as strings (converted by .astype(int) below); the recorded outputs
# in this notebook came from the mldata copy, so e.g. labels[0] may differ.
try:
    from sklearn.datasets import fetch_mldata
    mnist = fetch_mldata('MNIST original')
except (ImportError, OSError):
    from sklearn.datasets import fetch_openml
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
mnist.keys()
# Out: dict_keys(['DESCR', 'COL_NAMES', 'target', 'data'])
mnist['data'].shape, mnist['target'].shape
# Out: ((70000, 784), (70000,))
# Each row is a flattened 28x28 image; restore the 2-D pixel grid.
images = np.reshape(mnist['data'], (-1, 28, 28))
labels = mnist['target'].astype(int)
n = len(images)
images.shape, labels.shape
# Out: ((70000, 28, 28), (70000,))
# Scale raw 0-255 pixel intensities into [0, 1].
images = images / 255
plot(images[0])
labels[0]
# Out: 0
plots(images[:5], titles=labels[:5])
# We can zoom in on part of the image:
# Zoom in: rows 0-13 and columns 8-21 of the first image.
plot(images[0, :14, 8:22])
# Next, we look at how to create an edge detector:
# 3x3 top-edge kernel. Under correlation, each output pixel is the sum of the
# 3-pixel neighbourhood in its own row minus the same sum in the row above,
# so it responds where a bright row sits below a darker one.
top = [
    [-1, -1, -1],
    [1, 1, 1],
    [0, 0, 0],
]
plot(top)
# Crop window (rows 10-27, columns 3-12) for inspecting pixel values.
dims = (slice(10, 28, 1), slice(3, 13))
images[0][dims]
array([[ 0. , 0. , 0. , 0. , 0. , 0. , 0.1882, 0.9333, 0.9882, 0.9882], [ 0. , 0. , 0. , 0. , 0. , 0.149 , 0.6471, 0.9922, 0.9137, 0.8157], [ 0. , 0. , 0. , 0. , 0.0275, 0.698 , 0.9882, 0.9412, 0.2784, 0.0745], [ 0. , 0. , 0. , 0. , 0.2235, 0.9882, 0.9882, 0.2471, 0. , 0. ], [ 0. , 0. , 0. , 0. , 0.7765, 0.9922, 0.7451, 0. , 0. , 0. ], [ 0. , 0. , 0. , 0.298 , 0.9647, 0.9882, 0.4392, 0. , 0. , 0. ], [ 0. , 0. , 0. , 0.3333, 0.9882, 0.902 , 0.098 , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0.3333, 0.9882, 0.8745, 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0.3333, 0.9882, 0.5686, 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0.3373, 0.9922, 0.8824, 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0.3333, 0.9882, 0.9765, 0.5725, 0.1882, 0.1137, 0.3333], [ 0. , 0. , 0. , 0.3333, 0.9882, 0.9882, 0.9882, 0.898 , 0.8431, 0.9882], [ 0. , 0. , 0. , 0.1098, 0.7804, 0.9882, 0.9882, 0.9922, 0.9882, 0.9882], [ 0. , 0. , 0. , 0. , 0.098 , 0.502 , 0.9882, 0.9922, 0.9882, 0.5529], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]])
# Run the top-edge kernel over the first image, then view the same crop.
corrtop = correlate(images[0], weights=top)
corrtop[dims]
array([[ 0. , 0. , 0. , 0. , 0. , 0.1882, 0.9216, 0.9765, 0.7843, -0.2392], [ 0. , 0. , 0. , 0. , 0.149 , 0.6078, 0.6667, 0.4431, -0.1882, -0.6196], [ 0. , 0. , 0. , 0.0275, 0.5765, 0.9176, 0.8392, -0.3451, -1.4275, -1.5961], [ 0. , 0. , 0. , 0.1961, 0.4863, 0.4863, -0.4039, -0.9725, -1.0471, -0.4627], [ 0. , 0. , 0. , 0.5529, 0.5569, 0.3137, -0.4863, -0.4902, -0.2471, 0. ], [ 0. , 0. , 0.298 , 0.4863, 0.4824, -0.1216, -0.3098, -0.3059, 0. , 0. ], [ 0. , 0. , 0.0353, 0.0588, -0.0275, -0.4039, -0.4275, -0.3412, 0. , 0. ], [ 0. , 0. , 0. , 0. , -0.0275, -0.1255, -0.1255, -0.098 , 0. , 0. ], [ 0. , 0. , 0. , 0. , -0.3059, -0.3059, -0.3059, 0. , 0. , 0. ], [ 0. , 0. , 0.0039, 0.0078, 0.3216, 0.3176, 0.3137, 0. , 0. , 0. ], [ 0. , 0. , -0.0039, -0.0078, 0.0863, 0.6627, 0.8549, 0.8745, 0.6353, 1.1451], [ 0. , 0. , 0. , 0. , 0.0118, 0.4275, 1.1373, 1.8549, 2.0941, 1.6745], [ 0. , 0. , -0.2235, -0.4314, -0.4314, -0.2078, 0.0941, 0.2392, 0.2392, 0.0706], [ 0. , 0. , -0.1098, -0.7922, -1.2784, -1.1686, -0.4863, 0. , -0.4353, -1.2039], [ 0. , 0. , 0. , -0.098 , -0.6 , -1.5882, -2.4824, -2.9686, -2.5333, -1.6863], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ], [ 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. ]])
plot(corrtop)
# One quarter-turn (counter-clockwise) turns the top-edge kernel into a
# left-edge kernel: -1 on the left column, +1 in the centre column.
np.rot90(top, 1)
# Out: array([[-1, 1, 0], [-1, 1, 0], [-1, 1, 0]])
# Convolution flips the kernel along both axes, so convolving with the kernel
# rotated by 180 degrees reproduces the correlation with the original kernel.
convtop = convolve(images[0], np.rot90(top, 2))
plot(convtop)
np.allclose(convtop, corrtop)
# Out: True
# Four straight-edge kernels: `top` and its 1, 2 and 3 quarter-turn rotations.
straights = [np.rot90(top, quarter_turns) for quarter_turns in range(4)]
plots(straights)

# Diagonal kernel: +1 along the anti-diagonal, -1.5 just below/right of it.
br = [
    [0, 0, 1],
    [0, 1, -1.5],
    [1, -1.5, 0],
]
diags = [np.rot90(br, quarter_turns) for quarter_turns in range(4)]
plots(diags)

# The full bank of eight edge kernels and their responses on the first image.
rots = straights + diags
corrs = [correlate(images[0], kernel) for kernel in rots]
plots(corrs)
# Select every 8 and every 1 with a vectorized boolean mask instead of a
# Python loop over every index; the results are lists of 28x28 arrays as before.
eights = list(images[labels == 8])
ones = list(images[labels == 1])
plots(eights[:5])
plots(ones[:5])
def normalize(arr):
    """Return `arr` standardized to zero mean and unit standard deviation.

    The mean and std are taken over all elements together, so a stack of
    filters is standardized jointly rather than per filter.
    """
    return (arr - arr.mean()) / arr.std()


def pool(im):
    """Downsample `im` by taking the max over non-overlapping 7x7 blocks."""
    # NOTE(review): `pool` was used in this notebook but never defined here.
    # This assumes 7x7 max-pooling via block_reduce, which is imported at the
    # top but otherwise unused. Confirm the block size.
    return block_reduce(im, (7, 7), np.max)


# For each edge kernel: the pooled edge maps of every 8 ...
pool8 = [np.array([pool(correlate(im, rot)) for im in eights]) for rot in rots]
# ... averaged over all 8s into one template per kernel, then standardized.
filts8 = normalize(np.array([ims.mean(axis=0) for ims in pool8]))
plots(filts8)

# Same templates for the 1s.
pool1 = [np.array([pool(correlate(im, rot)) for im in ones]) for rot in rots]
filts1 = normalize(np.array([ims.mean(axis=0) for ims in pool1]))
plots(filts1)
def pool_corr(im):
    """Return the pooled edge maps of `im`, one per kernel in `rots`."""
    return np.array([pool(correlate(im, rot)) for rot in rots])


plots(pool_corr(eights[0]))


def sse(a, b):
    """Return the sum of squared differences between arrays `a` and `b`."""
    return ((a - b) ** 2).sum()


def is8_n2(im):
    """Return 1 if `im` is closer (by SSE) to the 8-templates than the 1-templates, else 0."""
    # The correlations are the expensive part; compute them only once.
    feats = pool_corr(im)
    return 1 if sse(feats, filts1) > sse(feats, filts8) else 0


sse(pool_corr(eights[0]), filts8), sse(pool_corr(eights[0]), filts1)
# Out: (126.77776, 181.26105)

# Classify every image once and reuse the predictions for both tallies.
preds_n2 = [np.array([is8_n2(im) for im in ims]) for ims in [eights, ones]]
# Number of 8s / 1s predicted to be an 8:
[p.sum() for p in preds_n2]
# Out: [5223, 287]
# Number of 8s / 1s predicted NOT to be an 8:
[(1 - p).sum() for p in preds_n2]
# Out: [166, 5892]
def n1(a, b):
    """Return the L1 distance (sum of absolute differences) between arrays `a` and `b`."""
    return np.fabs(a - b).sum()


def is8_n1(im):
    """Return 1 if `im`'s pooled edge maps are closer (by L1) to the 8-templates than the 1-templates, else 0."""
    # The correlations are the expensive part; compute them only once.
    feats = pool_corr(im)
    return 1 if n1(feats, filts1) > n1(feats, filts8) else 0


# Classify every image once and reuse the predictions for both tallies.
preds_n1 = [np.array([is8_n1(im) for im in ims]) for ims in [eights, ones]]
# Number of 8s / 1s predicted to be an 8:
[p.sum() for p in preds_n1]
# Number of 8s / 1s predicted NOT to be an 8:
[(1 - p).sum() for p in preds_n1]