from keras.datasets import mnist
import numpy as np
import pandas as pd
# Base class
class Layer:
    """Common interface every network layer implements.

    ``input`` and ``output`` start out as ``None``; concrete layers overwrite
    them during ``forward_propagation`` so the backward pass can reuse them.
    """

    def __init__(self):
        self.input = self.output = None

    def forward_propagation(self, input):
        """Return the layer output Y for the given input X (must be overridden)."""
        raise NotImplementedError

    def backward_propagation(self, output_error, learning_rate):
        """Turn dE/dY into dE/dX, updating parameters if any (must be overridden)."""
        raise NotImplementedError
class FCLayer(Layer):
    """Fully connected (dense) layer computing ``input @ weights + bias``.

    Parameters
    ----------
    input_size : int
        Number of input features (rows of ``weights``).
    output_size : int
        Number of output features (columns of ``weights`` and ``bias``).
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Uniform initialisation in [-0.5, 0.5).
        self.weights = np.random.rand(input_size, output_size) - 0.5
        self.bias = np.random.rand(1, output_size) - 0.5

    def forward_propagation(self, input_data):
        """Cache ``input_data`` and return ``input_data @ weights + bias``."""
        self.input = input_data
        self.output = np.dot(self.input, self.weights) + self.bias
        return self.output

    def backward_propagation(self, output_error, learning_rate):
        """Take one SGD step on weights and bias and return dE/dX.

        ``output_error`` is dE/dY. Each row is treated as one sample, so a
        (batch, output_size) error sums its gradients over the batch for both
        the weights and the bias.
        """
        # dE/dX must use the weights *before* they are updated.
        input_error = np.dot(output_error, self.weights.T)
        weights_error = np.dot(self.input.T, output_error)
        # Sum over rows so the bias gradient keeps the (1, output_size) shape of
        # self.bias even when several samples are propagated at once.
        bias_error = np.sum(np.atleast_2d(output_error), axis=0, keepdims=True)
        self.weights -= learning_rate * weights_error
        self.bias -= learning_rate * bias_error
        return input_error
class ActivationLayer(Layer):
    """Element-wise activation layer; tanh unless another function is supplied.

    Parameters
    ----------
    activation : callable, optional
        Element-wise function f(x). Defaults to ``np.tanh``.
    activation_prime : callable, optional
        Its derivative f'(x). Must be given together with ``activation``.

    Raises
    ------
    ValueError
        If only one of ``activation`` / ``activation_prime`` is given.
    """

    def __init__(self, activation=None, activation_prime=None):
        super().__init__()
        if (activation is None) != (activation_prime is None):
            raise ValueError("activation and activation_prime must be given together")
        if activation is None:
            activation = np.tanh
            activation_prime = lambda x: 1 - np.tanh(x) ** 2
        self.activation = activation
        self.activation_prime = activation_prime

    def forward_propagation(self, input_data):
        """Cache ``input_data`` and return ``activation(input_data)``."""
        self.input = input_data
        self.output = self.activation(self.input)
        return self.output

    def backward_propagation(self, output_error, learning_rate):
        """Return dE/dX = f'(X) * dE/dY (no trainable parameters, so the rate is unused)."""
        return self.activation_prime(self.input) * output_error
# loss function and its derivative
def mse(y_true, y_pred):
    """Mean squared error averaged over every element of the inputs."""
    residual = y_true - y_pred
    return np.mean(residual ** 2)
def mse_prime(y_true, y_pred):
    """Gradient of the element-averaged squared error with respect to ``y_pred``."""
    count = y_true.size
    return 2 * (y_pred - y_true) / count
class Network:
    """Sequential network trained one sample at a time (online gradient descent).

    Layers run in the order they were added. Each layer must provide
    ``forward_propagation(input)`` and
    ``backward_propagation(output_error, learning_rate)``.
    """

    def __init__(self):
        self.layers = []
        self.loss = None        # loss(y_true, y_pred) -> scalar
        self.loss_prime = None  # derivative of loss with respect to y_pred

    def add(self, layer):
        """Append ``layer`` to the end of the network."""
        self.layers.append(layer)

    def use(self, loss, loss_prime):
        """Set the loss function and its derivative used by ``fit``."""
        self.loss = loss
        self.loss_prime = loss_prime

    def _forward(self, sample):
        """Run a single sample through every layer and return the final output."""
        output = sample
        for layer in self.layers:
            output = layer.forward_propagation(output)
        return output

    def predict(self, input_data):
        """Return a list holding the network output for each sample in ``input_data``."""
        return [self._forward(sample) for sample in input_data]

    def fit(self, x_train, y_train, epochs, learning_rate, *, verbose=True):
        """Train for ``epochs`` passes, updating weights after every sample.

        Parameters
        ----------
        x_train, y_train : sequences of equal length
            Training inputs and their targets.
        epochs : int
            Number of passes over the training data.
        learning_rate : float
            Step size handed to each layer's ``backward_propagation``.
        verbose : bool, keyword-only
            Print the mean loss after every epoch (default True).

        Returns
        -------
        list
            Mean loss per sample for each epoch; the same value that is printed.

        Raises
        ------
        RuntimeError
            If ``use`` has not been called to set the loss functions.
        ValueError
            If ``x_train`` is empty or its length differs from ``y_train``.
        """
        if self.loss is None or self.loss_prime is None:
            raise RuntimeError("call use(loss, loss_prime) before fit()")
        samples = len(x_train)
        if samples == 0:
            raise ValueError("x_train must contain at least one sample")
        if len(y_train) != samples:
            raise ValueError(
                f"x_train and y_train lengths differ ({samples} != {len(y_train)})")

        errors = []
        for epoch in range(epochs):
            total = 0
            for x, y in zip(x_train, y_train):
                output = self._forward(x)
                total += self.loss(y, output)
                error = self.loss_prime(y, output)
                for layer in reversed(self.layers):
                    error = layer.backward_propagation(error, learning_rate)
            # Report the per-sample mean, consistent with the returned history.
            mean_error = total / samples
            errors.append(mean_error)
            if verbose:
                print(f"epoch {epoch + 1}/{epochs} error={mean_error:f}")
        return errors
# XOR toy problem: every input is a (1, 2) row and every target a (1, 1) row,
# enumerated in the order (0,0), (0,1), (1,0), (1,1).
_bits = (0, 1)
x_train = np.array([[[a, b]] for a in _bits for b in _bits])
y_train = np.array([[[a ^ b]] for a in _bits for b in _bits])
net = Network()
for _layer in (FCLayer(2, 5), ActivationLayer(), FCLayer(5, 1), ActivationLayer()):
    net.add(_layer)
net.use(mse, mse_prime)
err = net.fit(x_train, y_train, epochs=1000, learning_rate=0.1)
epoch 1/1000 error=3.439531 epoch 2/1000 error=1.469290 epoch 3/1000 error=1.249098 epoch 4/1000 error=1.221891 epoch 5/1000 error=1.213275 epoch 6/1000 error=1.208269 epoch 7/1000 error=1.204283 epoch 8/1000 error=1.200643 epoch 9/1000 error=1.197125 epoch 10/1000 error=1.193641 epoch 11/1000 error=1.190152 epoch 12/1000 error=1.186639 epoch 13/1000 error=1.183090 epoch 14/1000 error=1.179496 epoch 15/1000 error=1.175854 epoch 16/1000 error=1.172157 epoch 17/1000 error=1.168403 epoch 18/1000 error=1.164588 epoch 19/1000 error=1.160709 epoch 20/1000 error=1.156765 epoch 21/1000 error=1.152753 epoch 22/1000 error=1.148673 epoch 23/1000 error=1.144523 epoch 24/1000 error=1.140304 epoch 25/1000 error=1.136016 epoch 26/1000 error=1.131659 epoch 27/1000 error=1.127236 epoch 28/1000 error=1.122749 epoch 29/1000 error=1.118200 epoch 30/1000 error=1.113594 epoch 31/1000 error=1.108935 epoch 32/1000 error=1.104229 epoch 33/1000 error=1.099480 epoch 34/1000 error=1.094695 epoch 35/1000 error=1.089882 epoch 36/1000 error=1.085048 epoch 37/1000 error=1.080200 epoch 38/1000 error=1.075348 epoch 39/1000 error=1.070499 epoch 40/1000 error=1.065664 epoch 41/1000 error=1.060849 epoch 42/1000 error=1.056065 epoch 43/1000 error=1.051319 epoch 44/1000 error=1.046620 epoch 45/1000 error=1.041977 epoch 46/1000 error=1.037396 epoch 47/1000 error=1.032884 epoch 48/1000 error=1.028448 epoch 49/1000 error=1.024094 epoch 50/1000 error=1.019825 epoch 51/1000 error=1.015647 epoch 52/1000 error=1.011563 epoch 53/1000 error=1.007575 epoch 54/1000 error=1.003685 epoch 55/1000 error=0.999895 epoch 56/1000 error=0.996204 epoch 57/1000 error=0.992613 epoch 58/1000 error=0.989120 epoch 59/1000 error=0.985724 epoch 60/1000 error=0.982424 epoch 61/1000 error=0.979217 epoch 62/1000 error=0.976100 epoch 63/1000 error=0.973071 epoch 64/1000 error=0.970126 epoch 65/1000 error=0.967262 epoch 66/1000 error=0.964476 epoch 67/1000 error=0.961764 epoch 68/1000 error=0.959122 epoch 69/1000 error=0.956546 epoch 
70/1000 error=0.954034 epoch 71/1000 error=0.951581 epoch 72/1000 error=0.949184 epoch 73/1000 error=0.946839 epoch 74/1000 error=0.944543 epoch 75/1000 error=0.942292 epoch 76/1000 error=0.940085 epoch 77/1000 error=0.937916 epoch 78/1000 error=0.935783 epoch 79/1000 error=0.933684 epoch 80/1000 error=0.931615 epoch 81/1000 error=0.929575 epoch 82/1000 error=0.927559 epoch 83/1000 error=0.925565 epoch 84/1000 error=0.923592 epoch 85/1000 error=0.921636 epoch 86/1000 error=0.919696 epoch 87/1000 error=0.917768 epoch 88/1000 error=0.915851 epoch 89/1000 error=0.913942 epoch 90/1000 error=0.912039 epoch 91/1000 error=0.910140 epoch 92/1000 error=0.908242 epoch 93/1000 error=0.906344 epoch 94/1000 error=0.904442 epoch 95/1000 error=0.902535 epoch 96/1000 error=0.900620 epoch 97/1000 error=0.898696 epoch 98/1000 error=0.896758 epoch 99/1000 error=0.894806 epoch 100/1000 error=0.892836 epoch 101/1000 error=0.890845 epoch 102/1000 error=0.888832 epoch 103/1000 error=0.886793 epoch 104/1000 error=0.884725 epoch 105/1000 error=0.882626 epoch 106/1000 error=0.880491 epoch 107/1000 error=0.878319 epoch 108/1000 error=0.876105 epoch 109/1000 error=0.873845 epoch 110/1000 error=0.871537 epoch 111/1000 error=0.869176 epoch 112/1000 error=0.866757 epoch 113/1000 error=0.864278 epoch 114/1000 error=0.861732 epoch 115/1000 error=0.859115 epoch 116/1000 error=0.856422 epoch 117/1000 error=0.853648 epoch 118/1000 error=0.850786 epoch 119/1000 error=0.847830 epoch 120/1000 error=0.844773 epoch 121/1000 error=0.841609 epoch 122/1000 error=0.838330 epoch 123/1000 error=0.834928 epoch 124/1000 error=0.831394 epoch 125/1000 error=0.827718 epoch 126/1000 error=0.823891 epoch 127/1000 error=0.819901 epoch 128/1000 error=0.815738 epoch 129/1000 error=0.811388 epoch 130/1000 error=0.806838 epoch 131/1000 error=0.802073 epoch 132/1000 error=0.797079 epoch 133/1000 error=0.791837 epoch 134/1000 error=0.786330 epoch 135/1000 error=0.780538 epoch 136/1000 error=0.774440 epoch 137/1000 
error=0.768014 epoch 138/1000 error=0.761234 epoch 139/1000 error=0.754075 epoch 140/1000 error=0.746509 epoch 141/1000 error=0.738505 epoch 142/1000 error=0.730032 epoch 143/1000 error=0.721055 epoch 144/1000 error=0.711538 epoch 145/1000 error=0.701443 epoch 146/1000 error=0.690730 epoch 147/1000 error=0.679356 epoch 148/1000 error=0.667280 epoch 149/1000 error=0.654457 epoch 150/1000 error=0.640844 epoch 151/1000 error=0.626401 epoch 152/1000 error=0.611087 epoch 153/1000 error=0.594869 epoch 154/1000 error=0.577720 epoch 155/1000 error=0.559625 epoch 156/1000 error=0.540582 epoch 157/1000 error=0.520606 epoch 158/1000 error=0.499736 epoch 159/1000 error=0.478038 epoch 160/1000 error=0.455608 epoch 161/1000 error=0.432576 epoch 162/1000 error=0.409107 epoch 163/1000 error=0.385399 epoch 164/1000 error=0.361679 epoch 165/1000 error=0.338191 epoch 166/1000 error=0.315186 epoch 167/1000 error=0.292907 epoch 168/1000 error=0.271572 epoch 169/1000 error=0.251362 epoch 170/1000 error=0.232410 epoch 171/1000 error=0.214797 epoch 172/1000 error=0.198556 epoch 173/1000 error=0.183675 epoch 174/1000 error=0.170107 epoch 175/1000 error=0.157781 epoch 176/1000 error=0.146609 epoch 177/1000 error=0.136496 epoch 178/1000 error=0.127345 epoch 179/1000 error=0.119061 epoch 180/1000 error=0.111556 epoch 181/1000 error=0.104747 epoch 182/1000 error=0.098560 epoch 183/1000 error=0.092927 epoch 184/1000 error=0.087789 epoch 185/1000 error=0.083091 epoch 186/1000 error=0.078788 epoch 187/1000 error=0.074837 epoch 188/1000 error=0.071202 epoch 189/1000 error=0.067850 epoch 190/1000 error=0.064754 epoch 191/1000 error=0.061887 epoch 192/1000 error=0.059227 epoch 193/1000 error=0.056755 epoch 194/1000 error=0.054454 epoch 195/1000 error=0.052307 epoch 196/1000 error=0.050301 epoch 197/1000 error=0.048423 epoch 198/1000 error=0.046663 epoch 199/1000 error=0.045010 epoch 200/1000 error=0.043456 epoch 201/1000 error=0.041993 epoch 202/1000 error=0.040613 epoch 203/1000 error=0.039310 
epoch 204/1000 error=0.038079 epoch 205/1000 error=0.036913 epoch 206/1000 error=0.035809 epoch 207/1000 error=0.034761 epoch 208/1000 error=0.033766 epoch 209/1000 error=0.032820 epoch 210/1000 error=0.031920 epoch 211/1000 error=0.031062 epoch 212/1000 error=0.030244 epoch 213/1000 error=0.029464 epoch 214/1000 error=0.028718 epoch 215/1000 error=0.028005 epoch 216/1000 error=0.027324 epoch 217/1000 error=0.026671 epoch 218/1000 error=0.026045 epoch 219/1000 error=0.025445 epoch 220/1000 error=0.024869 epoch 221/1000 error=0.024316 epoch 222/1000 error=0.023784 epoch 223/1000 error=0.023273 epoch 224/1000 error=0.022782 epoch 225/1000 error=0.022308 epoch 226/1000 error=0.021852 epoch 227/1000 error=0.021412 epoch 228/1000 error=0.020988 epoch 229/1000 error=0.020579 epoch 230/1000 error=0.020184 epoch 231/1000 error=0.019802 epoch 232/1000 error=0.019433 epoch 233/1000 error=0.019076 epoch 234/1000 error=0.018731 epoch 235/1000 error=0.018397 epoch 236/1000 error=0.018073 epoch 237/1000 error=0.017760 epoch 238/1000 error=0.017456 epoch 239/1000 error=0.017162 epoch 240/1000 error=0.016876 epoch 241/1000 error=0.016599 epoch 242/1000 error=0.016330 epoch 243/1000 error=0.016069 epoch 244/1000 error=0.015815 epoch 245/1000 error=0.015568 epoch 246/1000 error=0.015329 epoch 247/1000 error=0.015095 epoch 248/1000 error=0.014869 epoch 249/1000 error=0.014648 epoch 250/1000 error=0.014433 epoch 251/1000 error=0.014224 epoch 252/1000 error=0.014020 epoch 253/1000 error=0.013822 epoch 254/1000 error=0.013628 epoch 255/1000 error=0.013440 epoch 256/1000 error=0.013256 epoch 257/1000 error=0.013076 epoch 258/1000 error=0.012901 epoch 259/1000 error=0.012730 epoch 260/1000 error=0.012564 epoch 261/1000 error=0.012401 epoch 262/1000 error=0.012242 epoch 263/1000 error=0.012086 epoch 264/1000 error=0.011935 epoch 265/1000 error=0.011786 epoch 266/1000 error=0.011641 epoch 267/1000 error=0.011500 epoch 268/1000 error=0.011361 epoch 269/1000 error=0.011225 epoch 270/1000 
error=0.011092 epoch 271/1000 error=0.010963 epoch 272/1000 error=0.010835 epoch 273/1000 error=0.010711 epoch 274/1000 error=0.010589 epoch 275/1000 error=0.010470 epoch 276/1000 error=0.010353 epoch 277/1000 error=0.010238 epoch 278/1000 error=0.010126 epoch 279/1000 error=0.010016 epoch 280/1000 error=0.009908 epoch 281/1000 error=0.009802 epoch 282/1000 error=0.009699 epoch 283/1000 error=0.009597 epoch 284/1000 error=0.009497 epoch 285/1000 error=0.009399 epoch 286/1000 error=0.009303 epoch 287/1000 error=0.009209 epoch 288/1000 error=0.009116 epoch 289/1000 error=0.009025 epoch 290/1000 error=0.008936 epoch 291/1000 error=0.008849 epoch 292/1000 error=0.008763 epoch 293/1000 error=0.008678 epoch 294/1000 error=0.008595 epoch 295/1000 error=0.008513 epoch 296/1000 error=0.008433 epoch 297/1000 error=0.008354 epoch 298/1000 error=0.008277 epoch 299/1000 error=0.008201 epoch 300/1000 error=0.008126 epoch 301/1000 error=0.008052 epoch 302/1000 error=0.007980 epoch 303/1000 error=0.007909 epoch 304/1000 error=0.007839 epoch 305/1000 error=0.007770 epoch 306/1000 error=0.007702 epoch 307/1000 error=0.007635 epoch 308/1000 error=0.007569 epoch 309/1000 error=0.007505 epoch 310/1000 error=0.007441 epoch 311/1000 error=0.007378 epoch 312/1000 error=0.007317 epoch 313/1000 error=0.007256 epoch 314/1000 error=0.007196 epoch 315/1000 error=0.007137 epoch 316/1000 error=0.007079 epoch 317/1000 error=0.007022 epoch 318/1000 error=0.006966 epoch 319/1000 error=0.006910 epoch 320/1000 error=0.006855 epoch 321/1000 error=0.006802 epoch 322/1000 error=0.006748 epoch 323/1000 error=0.006696 epoch 324/1000 error=0.006644 epoch 325/1000 error=0.006594 epoch 326/1000 error=0.006543 epoch 327/1000 error=0.006494 epoch 328/1000 error=0.006445 epoch 329/1000 error=0.006397 epoch 330/1000 error=0.006349 epoch 331/1000 error=0.006303 epoch 332/1000 error=0.006256 epoch 333/1000 error=0.006211 epoch 334/1000 error=0.006166 epoch 335/1000 error=0.006122 epoch 336/1000 error=0.006078 
epoch 337/1000 error=0.006035 epoch 338/1000 error=0.005992 epoch 339/1000 error=0.005950 epoch 340/1000 error=0.005908 epoch 341/1000 error=0.005867 epoch 342/1000 error=0.005827 epoch 343/1000 error=0.005787 epoch 344/1000 error=0.005747 epoch 345/1000 error=0.005709 epoch 346/1000 error=0.005670 epoch 347/1000 error=0.005632 epoch 348/1000 error=0.005594 epoch 349/1000 error=0.005557 epoch 350/1000 error=0.005521 epoch 351/1000 error=0.005485 epoch 352/1000 error=0.005449 epoch 353/1000 error=0.005414 epoch 354/1000 error=0.005379 epoch 355/1000 error=0.005344 epoch 356/1000 error=0.005310 epoch 357/1000 error=0.005276 epoch 358/1000 error=0.005243 epoch 359/1000 error=0.005210 epoch 360/1000 error=0.005178 epoch 361/1000 error=0.005146 epoch 362/1000 error=0.005114 epoch 363/1000 error=0.005082 epoch 364/1000 error=0.005051 epoch 365/1000 error=0.005021 epoch 366/1000 error=0.004990 epoch 367/1000 error=0.004960 epoch 368/1000 error=0.004931 epoch 369/1000 error=0.004901 epoch 370/1000 error=0.004872 epoch 371/1000 error=0.004843 epoch 372/1000 error=0.004815 epoch 373/1000 error=0.004787 epoch 374/1000 error=0.004759 epoch 375/1000 error=0.004732 epoch 376/1000 error=0.004704 epoch 377/1000 error=0.004678 epoch 378/1000 error=0.004651 epoch 379/1000 error=0.004625 epoch 380/1000 error=0.004598 epoch 381/1000 error=0.004573 epoch 382/1000 error=0.004547 epoch 383/1000 error=0.004522 epoch 384/1000 error=0.004497 epoch 385/1000 error=0.004472 epoch 386/1000 error=0.004448 epoch 387/1000 error=0.004423 epoch 388/1000 error=0.004399 epoch 389/1000 error=0.004376 epoch 390/1000 error=0.004352 epoch 391/1000 error=0.004329 epoch 392/1000 error=0.004306 epoch 393/1000 error=0.004283 epoch 394/1000 error=0.004260 epoch 395/1000 error=0.004238 epoch 396/1000 error=0.004216 epoch 397/1000 error=0.004194 epoch 398/1000 error=0.004172 epoch 399/1000 error=0.004150 epoch 400/1000 error=0.004129 epoch 401/1000 error=0.004108 epoch 402/1000 error=0.004087 epoch 403/1000 
error=0.004066 epoch 404/1000 error=0.004046 epoch 405/1000 error=0.004025 epoch 406/1000 error=0.004005 epoch 407/1000 error=0.003985 epoch 408/1000 error=0.003965 epoch 409/1000 error=0.003946 epoch 410/1000 error=0.003926 epoch 411/1000 error=0.003907 epoch 412/1000 error=0.003888 epoch 413/1000 error=0.003869 epoch 414/1000 error=0.003850 epoch 415/1000 error=0.003832 epoch 416/1000 error=0.003813 epoch 417/1000 error=0.003795 epoch 418/1000 error=0.003777 epoch 419/1000 error=0.003759 epoch 420/1000 error=0.003741 epoch 421/1000 error=0.003724 epoch 422/1000 error=0.003706 epoch 423/1000 error=0.003689 epoch 424/1000 error=0.003672 epoch 425/1000 error=0.003655 epoch 426/1000 error=0.003638 epoch 427/1000 error=0.003621 epoch 428/1000 error=0.003605 epoch 429/1000 error=0.003588 epoch 430/1000 error=0.003572 epoch 431/1000 error=0.003556 epoch 432/1000 error=0.003540 epoch 433/1000 error=0.003524 epoch 434/1000 error=0.003508 epoch 435/1000 error=0.003493 epoch 436/1000 error=0.003477 epoch 437/1000 error=0.003462 epoch 438/1000 error=0.003447 epoch 439/1000 error=0.003431 epoch 440/1000 error=0.003416 epoch 441/1000 error=0.003402 epoch 442/1000 error=0.003387 epoch 443/1000 error=0.003372 epoch 444/1000 error=0.003358 epoch 445/1000 error=0.003343 epoch 446/1000 error=0.003329 epoch 447/1000 error=0.003315 epoch 448/1000 error=0.003301 epoch 449/1000 error=0.003287 epoch 450/1000 error=0.003273 epoch 451/1000 error=0.003259 epoch 452/1000 error=0.003246 epoch 453/1000 error=0.003232 epoch 454/1000 error=0.003219 epoch 455/1000 error=0.003205 epoch 456/1000 error=0.003192 epoch 457/1000 error=0.003179 epoch 458/1000 error=0.003166 epoch 459/1000 error=0.003153 epoch 460/1000 error=0.003140 epoch 461/1000 error=0.003128 epoch 462/1000 error=0.003115 epoch 463/1000 error=0.003103 epoch 464/1000 error=0.003090 epoch 465/1000 error=0.003078 epoch 466/1000 error=0.003066 epoch 467/1000 error=0.003053 epoch 468/1000 error=0.003041 epoch 469/1000 error=0.003029 
epoch 470/1000 error=0.003018 epoch 471/1000 error=0.003006 epoch 472/1000 error=0.002994 epoch 473/1000 error=0.002982 epoch 474/1000 error=0.002971 epoch 475/1000 error=0.002959 epoch 476/1000 error=0.002948 epoch 477/1000 error=0.002937 epoch 478/1000 error=0.002926 epoch 479/1000 error=0.002914 epoch 480/1000 error=0.002903 epoch 481/1000 error=0.002892 epoch 482/1000 error=0.002882 epoch 483/1000 error=0.002871 epoch 484/1000 error=0.002860 epoch 485/1000 error=0.002849 epoch 486/1000 error=0.002839 epoch 487/1000 error=0.002828 epoch 488/1000 error=0.002818 epoch 489/1000 error=0.002807 epoch 490/1000 error=0.002797 epoch 491/1000 error=0.002787 epoch 492/1000 error=0.002777 epoch 493/1000 error=0.002766 epoch 494/1000 error=0.002756 epoch 495/1000 error=0.002746 epoch 496/1000 error=0.002737 epoch 497/1000 error=0.002727 epoch 498/1000 error=0.002717 epoch 499/1000 error=0.002707 epoch 500/1000 error=0.002698 epoch 501/1000 error=0.002688 epoch 502/1000 error=0.002679 epoch 503/1000 error=0.002669 epoch 504/1000 error=0.002660 epoch 505/1000 error=0.002650 epoch 506/1000 error=0.002641 epoch 507/1000 error=0.002632 epoch 508/1000 error=0.002623 epoch 509/1000 error=0.002614 epoch 510/1000 error=0.002605 epoch 511/1000 error=0.002596 epoch 512/1000 error=0.002587 epoch 513/1000 error=0.002578 epoch 514/1000 error=0.002569 epoch 515/1000 error=0.002560 epoch 516/1000 error=0.002552 epoch 517/1000 error=0.002543 epoch 518/1000 error=0.002535 epoch 519/1000 error=0.002526 epoch 520/1000 error=0.002518 epoch 521/1000 error=0.002509 epoch 522/1000 error=0.002501 epoch 523/1000 error=0.002493 epoch 524/1000 error=0.002484 epoch 525/1000 error=0.002476 epoch 526/1000 error=0.002468 epoch 527/1000 error=0.002460 epoch 528/1000 error=0.002452 epoch 529/1000 error=0.002444 epoch 530/1000 error=0.002436 epoch 531/1000 error=0.002428 epoch 532/1000 error=0.002420 epoch 533/1000 error=0.002412 epoch 534/1000 error=0.002404 epoch 535/1000 error=0.002397 epoch 536/1000 
error=0.002389 epoch 537/1000 error=0.002381 epoch 538/1000 error=0.002374 epoch 539/1000 error=0.002366 epoch 540/1000 error=0.002359 epoch 541/1000 error=0.002351 epoch 542/1000 error=0.002344 epoch 543/1000 error=0.002337 epoch 544/1000 error=0.002329 epoch 545/1000 error=0.002322 epoch 546/1000 error=0.002315 epoch 547/1000 error=0.002308 epoch 548/1000 error=0.002300 epoch 549/1000 error=0.002293 epoch 550/1000 error=0.002286 epoch 551/1000 error=0.002279 epoch 552/1000 error=0.002272 epoch 553/1000 error=0.002265 epoch 554/1000 error=0.002258 epoch 555/1000 error=0.002252 epoch 556/1000 error=0.002245 epoch 557/1000 error=0.002238 epoch 558/1000 error=0.002231 epoch 559/1000 error=0.002225 epoch 560/1000 error=0.002218 epoch 561/1000 error=0.002211 epoch 562/1000 error=0.002205 epoch 563/1000 error=0.002198 epoch 564/1000 error=0.002192 epoch 565/1000 error=0.002185 epoch 566/1000 error=0.002179 epoch 567/1000 error=0.002172 epoch 568/1000 error=0.002166 epoch 569/1000 error=0.002159 epoch 570/1000 error=0.002153 epoch 571/1000 error=0.002147 epoch 572/1000 error=0.002141 epoch 573/1000 error=0.002134 epoch 574/1000 error=0.002128 epoch 575/1000 error=0.002122 epoch 576/1000 error=0.002116 epoch 577/1000 error=0.002110 epoch 578/1000 error=0.002104 epoch 579/1000 error=0.002098 epoch 580/1000 error=0.002092 epoch 581/1000 error=0.002086 epoch 582/1000 error=0.002080 epoch 583/1000 error=0.002074 epoch 584/1000 error=0.002068 epoch 585/1000 error=0.002062 epoch 586/1000 error=0.002057 epoch 587/1000 error=0.002051 epoch 588/1000 error=0.002045 epoch 589/1000 error=0.002039 epoch 590/1000 error=0.002034 epoch 591/1000 error=0.002028 epoch 592/1000 error=0.002023 epoch 593/1000 error=0.002017 epoch 594/1000 error=0.002011 epoch 595/1000 error=0.002006 epoch 596/1000 error=0.002000 epoch 597/1000 error=0.001995 epoch 598/1000 error=0.001989 epoch 599/1000 error=0.001984 epoch 600/1000 error=0.001979 epoch 601/1000 error=0.001973 epoch 602/1000 error=0.001968 
epoch 603/1000 error=0.001963 epoch 604/1000 error=0.001957 epoch 605/1000 error=0.001952 epoch 606/1000 error=0.001947 epoch 607/1000 error=0.001942 epoch 608/1000 error=0.001937 epoch 609/1000 error=0.001931 epoch 610/1000 error=0.001926 epoch 611/1000 error=0.001921 epoch 612/1000 error=0.001916 epoch 613/1000 error=0.001911 epoch 614/1000 error=0.001906 epoch 615/1000 error=0.001901 epoch 616/1000 error=0.001896 epoch 617/1000 error=0.001891 epoch 618/1000 error=0.001886 epoch 619/1000 error=0.001881 epoch 620/1000 error=0.001876 epoch 621/1000 error=0.001872 epoch 622/1000 error=0.001867 epoch 623/1000 error=0.001862 epoch 624/1000 error=0.001857 epoch 625/1000 error=0.001852 epoch 626/1000 error=0.001848 epoch 627/1000 error=0.001843 epoch 628/1000 error=0.001838 epoch 629/1000 error=0.001834 epoch 630/1000 error=0.001829 epoch 631/1000 error=0.001824 epoch 632/1000 error=0.001820 epoch 633/1000 error=0.001815 epoch 634/1000 error=0.001811 epoch 635/1000 error=0.001806 epoch 636/1000 error=0.001802 epoch 637/1000 error=0.001797 epoch 638/1000 error=0.001793 epoch 639/1000 error=0.001788 epoch 640/1000 error=0.001784 epoch 641/1000 error=0.001779 epoch 642/1000 error=0.001775 epoch 643/1000 error=0.001771 epoch 644/1000 error=0.001766 epoch 645/1000 error=0.001762 epoch 646/1000 error=0.001758 epoch 647/1000 error=0.001753 epoch 648/1000 error=0.001749 epoch 649/1000 error=0.001745 epoch 650/1000 error=0.001741 epoch 651/1000 error=0.001736 epoch 652/1000 error=0.001732 epoch 653/1000 error=0.001728 epoch 654/1000 error=0.001724 epoch 655/1000 error=0.001720 epoch 656/1000 error=0.001716 epoch 657/1000 error=0.001711 epoch 658/1000 error=0.001707 epoch 659/1000 error=0.001703 epoch 660/1000 error=0.001699 epoch 661/1000 error=0.001695 epoch 662/1000 error=0.001691 epoch 663/1000 error=0.001687 epoch 664/1000 error=0.001683 epoch 665/1000 error=0.001679 epoch 666/1000 error=0.001675 epoch 667/1000 error=0.001672 epoch 668/1000 error=0.001668 epoch 669/1000 
error=0.001664 epoch 670/1000 error=0.001660 epoch 671/1000 error=0.001656 epoch 672/1000 error=0.001652 epoch 673/1000 error=0.001648 epoch 674/1000 error=0.001645 epoch 675/1000 error=0.001641 epoch 676/1000 error=0.001637 epoch 677/1000 error=0.001633 epoch 678/1000 error=0.001630 epoch 679/1000 error=0.001626 epoch 680/1000 error=0.001622 epoch 681/1000 error=0.001618 epoch 682/1000 error=0.001615 epoch 683/1000 error=0.001611 epoch 684/1000 error=0.001607 epoch 685/1000 error=0.001604 epoch 686/1000 error=0.001600 epoch 687/1000 error=0.001597 epoch 688/1000 error=0.001593 epoch 689/1000 error=0.001589 epoch 690/1000 error=0.001586 epoch 691/1000 error=0.001582 epoch 692/1000 error=0.001579 epoch 693/1000 error=0.001575 epoch 694/1000 error=0.001572 epoch 695/1000 error=0.001568 epoch 696/1000 error=0.001565 epoch 697/1000 error=0.001561 epoch 698/1000 error=0.001558 epoch 699/1000 error=0.001555 epoch 700/1000 error=0.001551 epoch 701/1000 error=0.001548 epoch 702/1000 error=0.001545 epoch 703/1000 error=0.001541 epoch 704/1000 error=0.001538 epoch 705/1000 error=0.001535 epoch 706/1000 error=0.001532 epoch 707/1000 error=0.001528 epoch 708/1000 error=0.001526 epoch 709/1000 error=0.001522 epoch 710/1000 error=0.001520 epoch 711/1000 error=0.001516 epoch 712/1000 error=0.001515 epoch 713/1000 error=0.001512 epoch 714/1000 error=0.001511 epoch 715/1000 error=0.001508 epoch 716/1000 error=0.001509 epoch 717/1000 error=0.001508 epoch 718/1000 error=0.001511 epoch 719/1000 error=0.001512 epoch 720/1000 error=0.001519 epoch 721/1000 error=0.001525 epoch 722/1000 error=0.001540 epoch 723/1000 error=0.001554 epoch 724/1000 error=0.001583 epoch 725/1000 error=0.001613 epoch 726/1000 error=0.001667 epoch 727/1000 error=0.001727 epoch 728/1000 error=0.001824 epoch 729/1000 error=0.001943 epoch 730/1000 error=0.002115 epoch 731/1000 error=0.002341 epoch 732/1000 error=0.002645 epoch 733/1000 error=0.003069 epoch 734/1000 error=0.003592 epoch 735/1000 error=0.004367 
epoch 736/1000 error=0.005231 epoch 737/1000 error=0.006599 epoch 738/1000 error=0.007921 epoch 739/1000 error=0.010194 epoch 740/1000 error=0.011947 epoch 741/1000 error=0.015368 epoch 742/1000 error=0.017131 epoch 743/1000 error=0.021579 epoch 744/1000 error=0.022405 epoch 745/1000 error=0.027193 epoch 746/1000 error=0.026114 epoch 747/1000 error=0.030312 epoch 748/1000 error=0.027216 epoch 749/1000 error=0.030249 epoch 750/1000 error=0.025948 epoch 751/1000 error=0.027774 epoch 752/1000 error=0.023250 epoch 753/1000 error=0.024124 epoch 754/1000 error=0.020011 epoch 755/1000 error=0.020246 epoch 756/1000 error=0.016800 epoch 757/1000 error=0.016656 epoch 758/1000 error=0.013900 epoch 759/1000 error=0.013560 epoch 760/1000 error=0.011418 epoch 761/1000 error=0.011000 epoch 762/1000 error=0.009362 epoch 763/1000 error=0.008933 epoch 764/1000 error=0.007694 epoch 765/1000 error=0.007292 epoch 766/1000 error=0.006359 epoch 767/1000 error=0.006001 epoch 768/1000 error=0.005301 epoch 769/1000 error=0.004992 epoch 770/1000 error=0.004466 epoch 771/1000 error=0.004205 epoch 772/1000 error=0.003810 epoch 773/1000 error=0.003592 epoch 774/1000 error=0.003296 epoch 775/1000 error=0.003116 epoch 776/1000 error=0.002893 epoch 777/1000 error=0.002744 epoch 778/1000 error=0.002576 epoch 779/1000 error=0.002454 epoch 780/1000 error=0.002327 epoch 781/1000 error=0.002227 epoch 782/1000 error=0.002130 epoch 783/1000 error=0.002048 epoch 784/1000 error=0.001974 epoch 785/1000 error=0.001907 epoch 786/1000 error=0.001851 epoch 787/1000 error=0.001795 epoch 788/1000 error=0.001752 epoch 789/1000 error=0.001706 epoch 790/1000 error=0.001672 epoch 791/1000 error=0.001634 epoch 792/1000 error=0.001608 epoch 793/1000 error=0.001576 epoch 794/1000 error=0.001556 epoch 795/1000 error=0.001529 epoch 796/1000 error=0.001513 epoch 797/1000 error=0.001490 epoch 798/1000 error=0.001478 epoch 799/1000 error=0.001458 epoch 800/1000 error=0.001449 epoch 801/1000 error=0.001432 epoch 802/1000 
error=0.001424 epoch 803/1000 error=0.001409 epoch 804/1000 error=0.001403 epoch 805/1000 error=0.001390 epoch 806/1000 error=0.001385 epoch 807/1000 error=0.001373 epoch 808/1000 error=0.001369 epoch 809/1000 error=0.001359 epoch 810/1000 error=0.001355 epoch 811/1000 error=0.001346 epoch 812/1000 error=0.001343 epoch 813/1000 error=0.001335 epoch 814/1000 error=0.001332 epoch 815/1000 error=0.001325 epoch 816/1000 error=0.001323 epoch 817/1000 error=0.001316 epoch 818/1000 error=0.001314 epoch 819/1000 error=0.001307 epoch 820/1000 error=0.001305 epoch 821/1000 error=0.001299 epoch 822/1000 error=0.001298 epoch 823/1000 error=0.001292 epoch 824/1000 error=0.001290 epoch 825/1000 error=0.001285 epoch 826/1000 error=0.001284 epoch 827/1000 error=0.001279 epoch 828/1000 error=0.001277 epoch 829/1000 error=0.001273 epoch 830/1000 error=0.001271 epoch 831/1000 error=0.001267 epoch 832/1000 error=0.001265 epoch 833/1000 error=0.001261 epoch 834/1000 error=0.001259 epoch 835/1000 error=0.001255 epoch 836/1000 error=0.001254 epoch 837/1000 error=0.001250 epoch 838/1000 error=0.001248 epoch 839/1000 error=0.001245 epoch 840/1000 error=0.001243 epoch 841/1000 error=0.001239 epoch 842/1000 error=0.001238 epoch 843/1000 error=0.001234 epoch 844/1000 error=0.001233 epoch 845/1000 error=0.001229 epoch 846/1000 error=0.001228 epoch 847/1000 error=0.001225 epoch 848/1000 error=0.001223 epoch 849/1000 error=0.001220 epoch 850/1000 error=0.001218 epoch 851/1000 error=0.001215 epoch 852/1000 error=0.001213 epoch 853/1000 error=0.001210 epoch 854/1000 error=0.001209 epoch 855/1000 error=0.001206 epoch 856/1000 error=0.001204 epoch 857/1000 error=0.001201 epoch 858/1000 error=0.001200 epoch 859/1000 error=0.001197 epoch 860/1000 error=0.001195 epoch 861/1000 error=0.001192 epoch 862/1000 error=0.001191 epoch 863/1000 error=0.001188 epoch 864/1000 error=0.001186 epoch 865/1000 error=0.001184 epoch 866/1000 error=0.001182 epoch 867/1000 error=0.001179 epoch 868/1000 error=0.001178 
epoch 869/1000 error=0.001175 epoch 870/1000 error=0.001174 epoch 871/1000 error=0.001171 epoch 872/1000 error=0.001169 epoch 873/1000 error=0.001167 epoch 874/1000 error=0.001165 epoch 875/1000 error=0.001163 epoch 876/1000 error=0.001161 epoch 877/1000 error=0.001159 epoch 878/1000 error=0.001157 epoch 879/1000 error=0.001155 epoch 880/1000 error=0.001153 epoch 881/1000 error=0.001150 epoch 882/1000 error=0.001149 epoch 883/1000 error=0.001146 epoch 884/1000 error=0.001145 epoch 885/1000 error=0.001142 epoch 886/1000 error=0.001141 epoch 887/1000 error=0.001139 epoch 888/1000 error=0.001137 epoch 889/1000 error=0.001135 epoch 890/1000 error=0.001133 epoch 891/1000 error=0.001131 epoch 892/1000 error=0.001129 epoch 893/1000 error=0.001127 epoch 894/1000 error=0.001125 epoch 895/1000 error=0.001123 epoch 896/1000 error=0.001121 epoch 897/1000 error=0.001119 epoch 898/1000 error=0.001118 epoch 899/1000 error=0.001115 epoch 900/1000 error=0.001114 epoch 901/1000 error=0.001112 epoch 902/1000 error=0.001110 epoch 903/1000 error=0.001108 epoch 904/1000 error=0.001106 epoch 905/1000 error=0.001104 epoch 906/1000 error=0.001103 epoch 907/1000 error=0.001101 epoch 908/1000 error=0.001099 epoch 909/1000 error=0.001097 epoch 910/1000 error=0.001095 epoch 911/1000 error=0.001093 epoch 912/1000 error=0.001092 epoch 913/1000 error=0.001090 epoch 914/1000 error=0.001088 epoch 915/1000 error=0.001086 epoch 916/1000 error=0.001085 epoch 917/1000 error=0.001082 epoch 918/1000 error=0.001081 epoch 919/1000 error=0.001079 epoch 920/1000 error=0.001077 epoch 921/1000 error=0.001075 epoch 922/1000 error=0.001074 epoch 923/1000 error=0.001072 epoch 924/1000 error=0.001070 epoch 925/1000 error=0.001068 epoch 926/1000 error=0.001067 epoch 927/1000 error=0.001065 epoch 928/1000 error=0.001064 epoch 929/1000 error=0.001062 epoch 930/1000 error=0.001060 epoch 931/1000 error=0.001058 epoch 932/1000 error=0.001057 epoch 933/1000 error=0.001055 epoch 934/1000 error=0.001053 epoch 935/1000 
error=0.001052 epoch 936/1000 error=0.001050 epoch 937/1000 error=0.001048 epoch 938/1000 error=0.001047 epoch 939/1000 error=0.001045 epoch 940/1000 error=0.001044 epoch 941/1000 error=0.001042 epoch 942/1000 error=0.001040 epoch 943/1000 error=0.001038 epoch 944/1000 error=0.001037 epoch 945/1000 error=0.001035 epoch 946/1000 error=0.001034 epoch 947/1000 error=0.001032 epoch 948/1000 error=0.001031 epoch 949/1000 error=0.001029 epoch 950/1000 error=0.001028 epoch 951/1000 error=0.001026 epoch 952/1000 error=0.001024 epoch 953/1000 error=0.001022 epoch 954/1000 error=0.001021 epoch 955/1000 error=0.001019 epoch 956/1000 error=0.001018 epoch 957/1000 error=0.001016 epoch 958/1000 error=0.001015 epoch 959/1000 error=0.001013 epoch 960/1000 error=0.001012 epoch 961/1000 error=0.001010 epoch 962/1000 error=0.001009 epoch 963/1000 error=0.001007 epoch 964/1000 error=0.001006 epoch 965/1000 error=0.001004 epoch 966/1000 error=0.001003 epoch 967/1000 error=0.001001 epoch 968/1000 error=0.001000 epoch 969/1000 error=0.000998 epoch 970/1000 error=0.000997 epoch 971/1000 error=0.000995 epoch 972/1000 error=0.000995 epoch 973/1000 error=0.000993 epoch 974/1000 error=0.000992 epoch 975/1000 error=0.000990 epoch 976/1000 error=0.000989 epoch 977/1000 error=0.000987 epoch 978/1000 error=0.000986 epoch 979/1000 error=0.000984 epoch 980/1000 error=0.000984 epoch 981/1000 error=0.000982 epoch 982/1000 error=0.000981 epoch 983/1000 error=0.000979 epoch 984/1000 error=0.000978 epoch 985/1000 error=0.000977 epoch 986/1000 error=0.000976 epoch 987/1000 error=0.000974 epoch 988/1000 error=0.000974 epoch 989/1000 error=0.000972 epoch 990/1000 error=0.000971 epoch 991/1000 error=0.000969 epoch 992/1000 error=0.000969 epoch 993/1000 error=0.000967 epoch 994/1000 error=0.000967 epoch 995/1000 error=0.000965 epoch 996/1000 error=0.000965 epoch 997/1000 error=0.000963 epoch 998/1000 error=0.000963 epoch 999/1000 error=0.000961 epoch 1000/1000 error=0.000961
pd.DataFrame({0: err}).plot()  # one line: training loss history returned by fit()
<matplotlib.axes._subplots.AxesSubplot at 0x7f8f771ea6d0>
# Evaluate the trained XOR network on its own inputs and show them side by side.
out = net.predict(input_data=x_train)
print(x_train, out)
[[[0 0]] [[0 1]] [[1 0]] [[1 1]]] [array([[0.00043794]]), array([[0.97852756]]), array([[0.97771572]]), array([[-0.00452103]])]
###### Digit recognition (MNIST)
# Load MNIST, flatten each 28x28 image into a (1, 784) float32 row divided by
# 255 (pixel values assumed to be 0-255), and one-hot encode the labels.
from keras.datasets import mnist
# keras.utils.np_utils is gone in newer Keras releases; to_categorical is
# exported directly from keras.utils in both old and new versions.
from keras.utils import to_categorical

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(x_train.shape[0], 1, 28 * 28).astype('float32') / 255
# Fix the class count so train and test labels always have the same width.
y_train = to_categorical(y_train, num_classes=10)
x_test = x_test.reshape(x_test.shape[0], 1, 28 * 28).astype('float32') / 255
y_test = to_categorical(y_test, num_classes=10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11493376/11490434 [==============================] - 0s 0us/step
y_train[0, :]  # one-hot label row of the first training example
array([0., 0., 0., 0., 0., 1., 0., 0., 0., 0.], dtype=float32)
# MNIST network: 784 -> 100 -> 50 -> 10, every dense layer followed by an
# ActivationLayer.
net = Network()
_layer_sizes = (28 * 28, 100, 50, 10)
for _n_in, _n_out in zip(_layer_sizes[:-1], _layer_sizes[1:]):
    net.add(FCLayer(_n_in, _n_out))
    net.add(ActivationLayer())
net.use(mse, mse_prime)
# Train on the first 5000 images only.
errors = net.fit(x_train[:5000], y_train[:5000], epochs=35, learning_rate=0.1)
epoch 1/35 error=597.229832 epoch 2/35 error=294.311087 epoch 3/35 error=222.290965 epoch 4/35 error=182.827325 epoch 5/35 error=156.561311 epoch 6/35 error=137.362874 epoch 7/35 error=122.995554 epoch 8/35 error=111.681362 epoch 9/35 error=102.327054 epoch 10/35 error=94.795020 epoch 11/35 error=88.419284 epoch 12/35 error=82.838767 epoch 13/35 error=78.053786 epoch 14/35 error=73.588500 epoch 15/35 error=69.571904 epoch 16/35 error=66.085256 epoch 17/35 error=62.907259 epoch 18/35 error=59.947271 epoch 19/35 error=57.193714 epoch 20/35 error=54.729624 epoch 21/35 error=52.550492 epoch 22/35 error=50.488494 epoch 23/35 error=48.678247 epoch 24/35 error=47.048339 epoch 25/35 error=45.462000 epoch 26/35 error=43.884629 epoch 27/35 error=42.421490 epoch 28/35 error=41.161154 epoch 29/35 error=39.919483 epoch 30/35 error=38.800290 epoch 31/35 error=37.824970 epoch 32/35 error=36.865852 epoch 33/35 error=35.948138 epoch 34/35 error=35.107472 epoch 35/35 error=34.259474
# Misclassification rate on the first 1000 test images. A prediction is wrong
# when the index of the largest network output differs from the index of the
# 1 in the one-hot label. (Summing signed residuals, as before, lets positive
# and negative errors cancel and hides misclassifications.)
_n_eval = 1000
_predictions = net.predict(x_test[:_n_eval])
errors = [int(np.argmax(pred) != np.argmax(label))
          for pred, label in zip(_predictions, y_test[:_n_eval])]
np.mean(errors)
0.028
Extend the digit network to handwritten characters.
Use the NIST EMNIST Letters dataset: 145,600 characters in 26 balanced classes. https://www.nist.gov/itl/products-and-services/emnist-dataset