from pylab import *
from scipy.ndimage import filters
def F(a): return array(a,'f')  # convert to a float32 array
figsize(12,6)
Let's start with a simple image processing problem: we want to recover a sharp image from a blurred, noisy image. We first construct a sharp image.
temp = F(mean(imread("page.png"),2))     # load and convert to grayscale
roi = (slice(500,1000),slice(500,1000))  # crop a region of interest
target = F(1-(temp[roi]>0.5))            # binarize and invert
imshow(target)
Now we artificially blur and degrade the image.
image = filters.gaussian_filter(target+0.2*randn(*target.shape),3.0)  # add noise, then blur
image -= amin(image); image /= amax(image)                            # rescale to [0,1]
imshow(image)
The simplest convolutional neural network has no hidden layer; it is the equivalent of a sigmoid regression. The formula is:

$$ Y = \sigma(F*X+\theta) $$

Here, $*$ denotes convolution, $\theta$ is a scalar bias, and $\sigma$ is the sigmoid function:
def sigmoid(x): return 1/(1+exp(-x))
r=10
filter = F(0.01*randn(2*r+1,2*r+1))  # small random initial weights
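The gradient derivation below relies on the identity $\sigma'(x)=\sigma(x)(1-\sigma(x))$. As a quick standalone sanity check (not part of the notebook's pipeline), this can be verified by finite differences:

```python
import numpy as np

def sigmoid(x):
    return 1/(1+np.exp(-x))

# finite-difference check of the identity sigma'(x) = sigma(x)(1-sigma(x))
x = np.linspace(-4, 4, 9)
h = 1e-6
numeric  = (sigmoid(x+h) - sigmoid(x-h)) / (2*h)
analytic = sigmoid(x) * (1 - sigmoid(x))
print(abs(numeric - analytic).max())  # tiny: numerical noise only
```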
When we initialize the filter with small random weights, the pre-activation is close to zero, so the output is nearly constant around $\sigma(0)=0.5$. This is the kind of processor we get:
pred = sigmoid(filters.convolve(image,filter))
subplot(121); imshow(pred)
subplot(122); imshow(pred,vmin=0,vmax=1)
amin(pred),amax(pred)
(0.43267395210529414, 0.50682288189825297)
Now, we can view the convolution basically as a large number of independent training problems. Consider the image $X_{\hat{p}}$, which is the input image $X$ shifted by $-p$. If we keep the filter fixed and shift the image, then the output $Y$ at pixel $p$ is given by:
$$Y(p) = F \cdot X_{\hat{p}} + \theta$$

Generally, the filter $F$ has a small footprint, meaning that it is zero outside a small region around the origin.
In essence, this problem is just like training with lots of separate training instances, except that we're trying to use a convolution operation to implement this. We need to keep track of the coordinates in the right way.
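To make this view concrete, here is a small standalone check (not from the notebook; it uses `scipy.ndimage.convolve`, which is the same function as `filters.convolve`): the convolution output at an interior pixel $p$ is just a dot product between the filter and the image patch around $p$, i.e., one "training instance" per pixel:

```python
import numpy as np
from scipy.ndimage import convolve

rng = np.random.RandomState(0)
img  = rng.rand(20, 20)
r    = 2
filt = rng.rand(2*r+1, 2*r+1)

out = convolve(img, filt)

# at an interior pixel, the convolution equals a dot product between the
# flipped filter and the image patch centred at that pixel
p = (10, 7)
patch = img[p[0]-r:p[0]+r+1, p[1]-r:p[1]+r+1]
print(np.allclose(out[p], np.sum(patch * filt[::-1, ::-1])))  # True
```

The flip is the usual convolution-vs-correlation convention; for the learning problem it makes no difference which one we use, as long as we are consistent.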
$$ \begin{eqnarray} \frac{\partial}{\partial F_{ij}} \sum_p (~T(p)-Y(p)~)^2 &=& -\sum_p 2(T(p)-Y(p)) ~~ \sigma'(F\cdot X_{\hat{p}}+\theta) ~~ X_{\hat{p},i,j}\\ &=& -\sum_p 2(T(p)-Y(p)) ~~ Y(p)(1-Y(p)) ~~ X_{\hat{p},i,j} \end{eqnarray} $$

Here, we define $\delta(p)$ as before:
$$ \delta(p) = 2(T(p)-Y(p)) ~~ Y(p)(1-Y(p)) $$

delta0 = (target-pred)
subplot(121); imshow(delta0)
delta = (target-pred)*pred*(1-pred)
subplot(122); imshow(delta)
This implements the update rule, the change to each individual weight.
dw = array([[sum(delta*roll(roll(image,i,0),j,1)) for j in range(-r,r+1)] for i in range(-r,r+1)])
dw /= prod(image.shape)
imshow(dw)
amin(dw),amax(dw),dw.dtype
(-0.020560804507159271, -0.0030023010284572597, dtype('float64'))
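Since `roll` wraps around at the boundaries, the summation above is a circular cross-correlation between `delta` and `image`, so all $(2r+1)^2$ entries of `dw` can also be computed with a single pair of FFTs. A standalone sketch of this equivalence (my own check, not part of the notebook):

```python
import numpy as np

rng = np.random.RandomState(0)
n, r = 16, 2
image = rng.rand(n, n)
delta = rng.rand(n, n)

# reference: the roll-based double loop from the text
dw_loop = np.array([[np.sum(delta*np.roll(np.roll(image,i,0),j,1))
                     for j in range(-r, r+1)] for i in range(-r, r+1)])

# dw[i,j] = sum_p delta[p] * image[p-(i,j)] is a circular cross-correlation,
# so all lags come from one FFT product
C = np.fft.ifft2(np.fft.fft2(delta) * np.conj(np.fft.fft2(image))).real
idx = np.arange(-r, r+1)
dw_fft = C[np.ix_(idx % n, idx % n)]

print(np.allclose(dw_loop, dw_fft))  # True
```

For the small filters used here the double loop is fast enough, but the FFT form scales much better with the filter radius.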
Now we perform the updates multiple times. Note that this is like a batch gradient update over all pixels in the image, so each "iter" is really more like an "epoch".
theta = 0.0
for iter in range(1000):
    pred = sigmoid(filters.convolve(image,filter)+theta)
    err = sum((pred-target)**2)
    delta = (pred-target)*pred*(1-pred)
    delta /= prod(image.shape)
    dw = array([[sum(delta*roll(roll(image,i,0),j,1)) for j in range(-r,r+1)] for i in range(-r,r+1)])
    if iter%50==0:
        print(iter,err,":",(amin(pred),amax(pred)),sum(abs(delta)),(amin(dw),amax(dw)),(amin(filter),amax(filter)),theta)
    filter -= dw
    theta -= sum(delta)
0 24760.208943 : (0.0035678377385413352, 0.9978953848980896) 0.0458853929868 (0.0053400223104518819, 0.0072862025192962669) (-0.12178462, 0.29784608) 0.0
50 20176.1372269 : (0.00066173566353980708, 0.99803984359858355) 0.0355319773236 (-0.0012748398445276302, 0.00045858820214109207) (-0.14058399, 0.36322036) -0.260421732761
100 19031.330067 : (0.00048385370845032913, 0.99854400091214424) 0.0326746901224 (-0.0010332582710076623, 0.00040143795443587651) (-0.1455725, 0.41569439) -0.454887283628
150 18198.192199 : (0.0004155409322442416, 0.99872334330467938) 0.0308582819579 (-0.00088536597466849765, 0.00036149074890512781) (-0.15812871, 0.45914328) -0.622232347642
200 17539.3167994 : (0.00038990543483953337, 0.99878464293543867) 0.0295672706387 (-0.00078260360495212408, 0.0003293243177035594) (-0.17461969, 0.50059199) -0.772475198537
250 16994.1057221 : (0.00038608381544868965, 0.99878856138959937) 0.0285879240528 (-0.00070586597014078928, 0.00030194455200845989) (-0.18923596, 0.53775758) -0.910363441817
300 16529.5810543 : (0.00037515339279886427, 0.99875869969945386) 0.0278121315518 (-0.00064585058238400637, 0.00027882636954881774) (-0.20231509, 0.57152367) -1.03855836321
350 16125.6248064 : (0.00037309558600197144, 0.99870606309593635) 0.0271780527159 (-0.00059739241228377329, 0.00025924393283338615) (-0.21412554, 0.60258859) -1.15873135473
400 15769.0215644 : (0.00037939348845828587, 0.99863653733030056) 0.026647204231 (-0.00055734803392893924, 0.00024253697679875588) (-0.22487944, 0.63144696) -1.27202312692
450 15450.6061906 : (0.00039192129822601023, 0.9985536631379115) 0.026194124504 (-0.00052366992786349181, 0.00022816327581808566) (-0.23474531, 0.6584661) -1.37926288474
500 15163.7442414 : (0.0004093551676040092, 0.99849197991033856) 0.0258011740743 (-0.00049495441071344083, 0.00021569034133553074) (-0.24385791, 0.68392771) -1.48108354287
550 14903.4588846 : (0.00043080094880172627, 0.99849865740375987) 0.0254556803338 (-0.00047019424500519458, 0.00020477904207624254) (-0.2523264, 0.70805389) -1.57798726124
600 14665.9044865 : (0.00045560509780877642, 0.99849411870271354) 0.0251482815005 (-0.00044864629512610902, 0.0001951601962971246) (-0.26023978, 0.73102361) -1.67038494714
650 14448.0330459 : (0.00048325261579012488, 0.99848106919570145) 0.0248718997265 (-0.00042974212412478607, 0.00018662198306495541) (-0.26767153, 0.7529828) -1.75862145479
700 14247.3764536 : (0.00051330978740367025, 0.99846150259590782) 0.0246210815827 (-0.0004130403155082855, 0.00017935258926101555) (-0.2746824, 0.77405226) -1.84299226766
750 14061.9001809 : (0.00054539353479446818, 0.99843693083452123) 0.0243915641533 (-0.00039819017585347391, 0.00017361989512953811) (-0.28132361, 0.79433364) -1.92375494443
800 13889.9005715 : (0.00057915236730348807, 0.99840853151815401) 0.0241799662986 (-0.00038490986354262734, 0.0001684039036092902) (-0.28763828, 0.81391168) -2.00113722712
850 13729.9336051 : (0.00061425805257500328, 0.99837723930106792) 0.023983579772 (-0.00037296916839079324, 0.00016363017738825521) (-0.29366297, 0.83285952) -2.07534289738
900 13580.7621414 : (0.00062288689370889759, 0.99834380726690386) 0.0238002149688 (-0.00036217905485174084, 0.0001592367759027625) (-0.29942873, 0.85123914) -2.14655607795
950 13441.3162875 : (0.00062606354808537336, 0.99830885078316589) 0.0236280854244 (-0.00035238058175589759, 0.00015517494449664518) (-0.30496234, 0.86910409) -2.21494452426
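The loop above is a full-batch update. Since every pixel is effectively its own training instance, one could instead update from single pixels stochastically; the sketch below is my own variant (not in the notebook), with a hypothetical learning rate `lr`:

```python
import numpy as np

def sgd_pixel_step(image, target, filt, theta, lr=0.1, rng=np.random):
    # one stochastic update from a single randomly chosen interior pixel
    r = filt.shape[0] // 2
    h, w = image.shape
    y = rng.randint(r, h - r)
    x = rng.randint(r, w - r)
    # flipped patch, so that sum(filt*patch) equals the convolution at (y,x)
    patch = image[y-r:y+r+1, x-r:x+r+1][::-1, ::-1]
    pred = 1/(1 + np.exp(-(np.sum(filt * patch) + theta)))
    delta = (pred - target[y, x]) * pred * (1 - pred)
    return filt - lr * delta * patch, theta - lr * delta
```

With many such steps this follows the same gradient in expectation as the batch loop, at the cost of noisier updates.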
As you can see, the nonlinear filter that we computed restores some of the crispness in the text. However, it is far from a good filter; for a good filter, we really need multiple units. In addition, we have to be very careful about scaling the inputs and the outputs, learning rates, etc.
figsize(20,6)
subplot(131); imshow(target)
subplot(132); imshow(image)
subplot(133); imshow(sigmoid(filters.convolve(image,filter)+theta))
Looking at the filter itself is also interesting:
imshow(filter)
Note that this filter combines a number of sources of information: