Sam Greydanus | September 2017 | MIT License
from __future__ import print_function
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
class Mnist():
def __init__(self, batch_size):
self.batch_size = batch_size
self.modes = modes = ['train', 'test']
trans = transforms.Compose([transforms.ToTensor(),]) # transforms.Normalize((0.1307,), (0.3081,))
dsets = {k: datasets.MNIST('./data', train=k=='train', download=True, transform=trans) for k in modes}
self.loaders = {k: torch.utils.data.DataLoader(dsets[k], batch_size=batch_size, shuffle=True) for k in modes}
def next(self, mode='train'):
X, y = next(iter(self.loaders[mode]))
return X.resize_(self.batch_size, 28**2).numpy(), y.numpy()
def get_accuracy(model, mnist, nsamples=10000, mode='test'):
assert mode in mnist.modes, 'incorrect mode supplied'
assert nsamples >= 10*mnist.batch_size
pool_size, correct = 0, 0
total_correct = 0 ; total = 0 ; acc_list = []
nbatches = int(nsamples/mnist.batch_size)
for _ in range(nbatches):
pool_size += mnist.batch_size
X, y = mnist.next(mode)
y_hat = model(X)
correct += sum(y_hat.argmax(axis=1) == y)
if pool_size > nsamples/10:
acc_list.append(100*correct/pool_size)
total_correct += correct
pool_size, correct = 0, 0
mean, std = 100*total_correct/(nbatches*mnist.batch_size), np.std(acc_list)
return mean, std
def forward(X, model):
# evaluate class scores, [N x K]
saved = {'X': X.copy()}
h1 = np.dot(X, model['W1']) + model['b1'] # linear layer
h1[h1<0] = 0 # relu
saved['h1'] = h1
h2 = np.dot(h1, model['W2']) + model['b2'] # linear layer
h2[h2<0] = 0 # relu
saved['h2'] = h2
e = np.exp(h2 - np.amax(h2, axis=1, keepdims=True))
probs = e / np.sum(e, axis=1, keepdims=True) # softmax cap
return probs, saved
def backward(y, y_hat, model, saved, smodel=None, use_stale_dhs=False):
batch_size = y.shape[0]
grad = { k : np.zeros_like(v) for k,v in model.items() }
# negative log likelihood
nll_loss = -np.log(y_hat[range(batch_size),y]).mean()
y_hat[range(batch_size),y] -= 1 # backwards through softmax
saved['dh2'] = y_hat/batch_size
# second hidden layer
grad['W2'] = np.dot(saved['h1'].T, saved['dh2'])
grad['b2'] = np.sum(saved['dh2'], axis=0, keepdims=True)
# first hidden layer
if smodel is None and not use_stale_dhs:
saved['dh1'] = np.dot(saved['dh2'], model['W2'].T) # backwards through linear layer
saved['dh1'][saved['h1'] <= 0] = 0 # backwards through relu
else:
saved['dh1'] = np.dot(saved['h1'], smodel['W1']) + smodel['b1'] # THE SYNTHETIC GRAD PART
grad['W1'] = np.dot(saved['X'].T, saved['dh1'])
grad['b1'] = np.sum(saved['dh1'], axis=0, keepdims=True)
return nll_loss, grad
def sbackward(saved_bp, saved_synth, smodel, saved):
sgrad = { k : np.zeros_like(v) for k,v in smodel.items() }
d_target = saved_synth['dh1'] - saved_bp['dh1']
sloss = (d_target**2).mean() # l2 loss
sgrad['W1'] += np.dot(saved_bp['h1'].T, d_target)
sgrad['b1'] += np.sum(d_target, axis=0, keepdims=True)
return sloss, sgrad
def train(model, smodel, mnist, args, horizon=20):
loss_hist = [] ; acc_hist = []
rmsprop = { k : np.zeros_like(v) for k,v in model.items() } # rmsprop gradient cache
for step in range(args['max_steps'] + 1):
X, y = mnist.next()
y_hat, saved = forward(X, model)
# get REAL loss and SYNTHETIC grad
loss, grad = backward(y, y_hat.copy(), model, saved, smodel, use_stale_dhs=args['use_synth_grad'])
# upgrade synthetic grad model
if step % args['real_grad_every'] == 0 or step < 100:
saved_synth = saved.copy()
loss, grad = backward(y, y_hat.copy(), model, saved, smodel=None) # get REAL loss and REAL grad
saved_bp = saved.copy()
sloss, sgrad = sbackward(saved_bp, saved_synth, smodel, saved)
# update synthetic model and regular model
smodel = {k : smodel[k] - args['lr']*sgrad[k] for (k,v) in sgrad.items()}
# parameter update
for k,v in model.items():
rmsprop[k] = args['rms_decay'] * rmsprop[k] + (1-args['rms_decay']) * grad[k]**2
if step > 100: model[k] -= args['lr'] * grad[k] / (np.sqrt(rmsprop[k]) + 1e-5)
# bookkeeping
end = '\r'
loss_hist += [loss]
run_loss = sum(loss_hist[-horizon:])/min(horizon, step+1)
if step % args['test_every'] == 0:
get_y_hat = lambda X: forward(X, model)[0]
acc, acc_std = get_accuracy(model=get_y_hat, mnist=mnist)
acc_hist += [[acc, acc_std]] ; end = '\n'
print('step {} | loss {:.4f} | sloss {:.4f} | acc {:.2f}% +/- {:.2f}'\
.format(step, run_loss, sloss, acc, acc_std), end=end)
return model, loss_hist, acc_hist
np.random.seed(1)
args = {'batch_size': 128, 'lr': 1e-4, 'slr': 5e-4,
'rms_decay': 0.9, 'use_synth_grad': True,
'real_grad_every': 10, 'h_size': 128, 'test_every': 500,
'max_steps': 20000, 'save_dir': './save/'}
model = {}
# first layer
model['W1'] = np.random.randn(28**2, args['h_size']) / np.sqrt(args['h_size']) # Xavier initialization
model['b1'] = np.zeros((1,args['h_size']))
model['W2'] = np.random.randn(args['h_size'],10) / np.sqrt(10)
model['b2'] = np.zeros((1,10))
# simplest possible synthetic gradient model
smodel = {}
smodel['W1'] = np.random.randn(args['h_size'], args['h_size']) / 1000*np.sqrt(args['h_size']) # Xavier initialization
smodel['b1'] = np.zeros((1,args['h_size']))
mnist = Mnist(args['batch_size'])
model, loss_hist, acc_hist = train(model, smodel, mnist, args)
step 0 | loss 2.4495 | sloss 0.0056 | acc 14.63% +/- 0.78 step 500 | loss 2.2617 | sloss 0.0013 | acc 20.21% +/- 1.55 step 1000 | loss 1.8909 | sloss 0.0012 | acc 38.99% +/- 1.46 step 1500 | loss 1.5282 | sloss 0.0011 | acc 50.89% +/- 1.63 step 2000 | loss 1.3576 | sloss 0.0009 | acc 56.30% +/- 1.54 step 2500 | loss 1.2664 | sloss 0.0007 | acc 61.39% +/- 1.37 step 3000 | loss 1.1686 | sloss 0.0006 | acc 64.14% +/- 1.91 step 3500 | loss 1.0784 | sloss 0.0005 | acc 65.98% +/- 0.92 step 4000 | loss 1.0126 | sloss 0.0005 | acc 67.65% +/- 1.24 step 4500 | loss 0.9041 | sloss 0.0004 | acc 70.54% +/- 1.19 step 5000 | loss 0.8006 | sloss 0.0003 | acc 72.33% +/- 1.37 step 5500 | loss 0.7861 | sloss 0.0005 | acc 72.95% +/- 0.86 step 6000 | loss 0.6961 | sloss 0.0004 | acc 73.36% +/- 1.09 step 6500 | loss 0.7221 | sloss 0.0038 | acc 75.33% +/- 0.52 step 7000 | loss 0.6787 | sloss 0.0124 | acc 75.77% +/- 1.32 step 7500 | loss 0.6248 | sloss 0.0413 | acc 76.23% +/- 1.04 step 8000 | loss 0.6224 | sloss 2.0212 | acc 77.13% +/- 1.26 step 8500 | loss 0.6301 | sloss 1.0077 | acc 77.38% +/- 1.22 step 9000 | loss 0.6225 | sloss 7.9353 | acc 77.26% +/- 0.592 step 9500 | loss 0.6163 | sloss 2.3852 | acc 78.03% +/- 0.83 step 10000 | loss 0.5827 | sloss 0.3317 | acc 77.82% +/- 0.84 step 10500 | loss 0.5724 | sloss 0.0559 | acc 79.48% +/- 1.16 step 11000 | loss 0.6000 | sloss 0.0349 | acc 78.90% +/- 1.53 step 11500 | loss 0.5917 | sloss 0.4440 | acc 79.05% +/- 1.02 step 12000 | loss 0.5643 | sloss 0.0283 | acc 78.85% +/- 0.92 step 12500 | loss 0.5716 | sloss 0.1773 | acc 79.36% +/- 0.74 step 13000 | loss 0.5374 | sloss 509.7853 | acc 79.27% +/- 0.90 step 13500 | loss 0.5584 | sloss 2787826118.5341 | acc 78.75% +/- 0.89 step 14000 | loss 0.5860 | sloss 1854409339849190.0000 | acc 79.73% +/- 1.49 step 14500 | loss 0.5571 | sloss 1159411052897763852288.0000 | acc 79.31% +/- 0.87 step 15000 | loss 0.5574 | sloss 1217358989235534299765145600.0000 | acc 79.76% +/- 0.88 step 15500 | loss 0.5252 | sloss 1840372574955508694215553781334016.0000 | acc 80.28% +/- 0.99 step 16000 | loss 0.5313 | sloss 34212036349956051780342850725481226960896.0000 | acc 79.79% +/- 0.72 step 16500 | loss 0.5208 | sloss 10155425591099732665021062667856065653389656064.0000 | acc 79.60% +/- 0.71 step 17000 | loss 0.5750 | sloss 920645847401773093735610303148147000305733188190208.0000 | acc 79.32% +/- 0.97 step 17500 | loss 0.5145 | sloss 72335248440776158930355332660249693899755182812177880842240.0000 | acc 80.20% +/- 0.97 step 18000 | loss 0.5317 | sloss 73755422586352825145720696072126228366141058290379419084145033216.0000 | acc 79.60% +/- 0.54 step 18500 | loss 0.5368 | sloss 13748000732113809897378583278923465935126901642073762438818558742888448.0000 | acc 80.00% +/- 0.88 step 19000 | loss 0.5539 | sloss 1959356605377122924897095046989033973094552236455804113301301074853518901248.0000 | acc 79.80% +/- 1.04 step 19500 | loss 0.5593 | sloss 773500764538326937915409912309303195924426661771220380749120353472761551810199552.0000 | acc 80.07% +/- 0.80 step 20000 | loss 0.5573 | sloss 185011477983408857614551473189440333554871558450342719987581792878342280641948336259072.0000 | acc 79.96% +/- 0.92
np.random.seed(1)
args['use_synth_grad'] = False
model = {}
# first layer
model['W1'] = np.random.randn(28**2, args['h_size']) / np.sqrt(args['h_size']) # Xavier initialization
model['b1'] = np.zeros((1,args['h_size']))
model['W2'] = np.random.randn(args['h_size'],10) / np.sqrt(10)
model['b2'] = np.zeros((1,10))
# simplest possible synthetic gradient model
smodel = {}
smodel['W1'] = np.random.randn(args['h_size'], args['h_size']) / 1000*np.sqrt(args['h_size']) # Xavier initialization
smodel['b1'] = np.zeros((1,args['h_size']))
mnist = Mnist(args['batch_size'])
model, loss_hist_cont, acc_hist_cont = train(model, smodel, mnist, args)
step 0 | loss 2.3332 | sloss 0.0054 | acc 15.13% +/- 0.85 step 500 | loss 2.2680 | sloss 0.0012 | acc 21.42% +/- 1.21 step 1000 | loss 1.8428 | sloss 0.0012 | acc 37.95% +/- 1.88 step 1500 | loss 1.5311 | sloss 0.0011 | acc 49.89% +/- 1.23 step 2000 | loss 1.3349 | sloss 0.0009 | acc 55.37% +/- 0.99 step 2500 | loss 1.1120 | sloss 0.0008 | acc 60.99% +/- 1.70 step 3000 | loss 1.0477 | sloss 0.0007 | acc 64.42% +/- 1.60 step 3500 | loss 1.0043 | sloss 0.0006 | acc 66.91% +/- 2.01 step 4000 | loss 0.9293 | sloss 0.0005 | acc 68.50% +/- 0.75 step 4500 | loss 0.8436 | sloss 0.0004 | acc 69.18% +/- 1.09 step 5000 | loss 0.8470 | sloss 0.0003 | acc 71.10% +/- 1.21 step 5500 | loss 0.8111 | sloss 0.0003 | acc 72.49% +/- 1.44 step 6000 | loss 0.7642 | sloss 0.0003 | acc 73.19% +/- 1.38 step 6500 | loss 0.7289 | sloss 0.0002 | acc 74.06% +/- 1.49 step 7000 | loss 0.7315 | sloss 0.0002 | acc 74.08% +/- 0.72 step 7500 | loss 0.6980 | sloss 0.0002 | acc 74.80% +/- 0.71 step 8000 | loss 0.6787 | sloss 0.0002 | acc 75.24% +/- 1.08 step 8500 | loss 0.6377 | sloss 0.0002 | acc 76.07% +/- 1.20 step 9000 | loss 0.6276 | sloss 0.0052 | acc 76.36% +/- 1.36 step 9500 | loss 0.6383 | sloss 25.1805 | acc 77.17% +/- 1.53 step 10000 | loss 0.5604 | sloss 40141.7027 | acc 77.08% +/- 0.82 step 10500 | loss 0.5722 | sloss 507843824.9864 | acc 76.93% +/- 1.51 step 11000 | loss 0.5609 | sloss 253798761869.4677 | acc 77.37% +/- 1.43 step 11500 | loss 0.6289 | sloss 2366971649537018.5000 | acc 76.85% +/- 0.90 step 12000 | loss 0.5633 | sloss 882522396107651678208.0000 | acc 77.68% +/- 1.31 step 12500 | loss 0.5697 | sloss 369619412451849767659503616.0000 | acc 77.85% +/- 0.97 step 13000 | loss 0.5439 | sloss 138499704970837754478723013279744.0000 | acc 78.36% +/- 0.94 step 13500 | loss 0.5033 | sloss 78425756593612332657651219749318688768.0000 | acc 78.71% +/- 1.02 step 14000 | loss 0.5342 | sloss 19230173456607795597025321201083143338065920.0000 | acc 79.06% +/- 0.69 step 14500 | loss 0.4808 | sloss 1538873712770341907664922143162240849767387103232.0000 | acc 78.64% +/- 0.89 step 15000 | loss 0.5346 | sloss 8530774412432267639021495911595377524487086545823596544.0000 | acc 78.63% +/- 1.19 step 15500 | loss 0.5645 | sloss 60178553620118115308721718882150494750616043849307759587622912.0000 | acc 79.69% +/- 0.89 step 16000 | loss 0.5271 | sloss 103164229061339560728897639958945788317896979760118464980100029874176.0000 | acc 79.24% +/- 0.86 step 16500 | loss 0.5359 | sloss 1790083375518323236807304153175456653517113269139663809142786607314926305280.0000 | acc 79.33% +/- 0.50 step 17000 | loss 0.5382 | sloss 85241701181064129221284290712284645811051276688600992159427396585365890026929389568.0000 | acc 79.22% +/- 1.15 step 17500 | loss 0.5643 | sloss 17689833724278372189125314093754779848062468157819256136246929496743436995286061998407680.0000 | acc 79.29% +/- 0.99 step 18000 | loss 0.4883 | sloss 37171949552046653534737151024844850577939170279305387855599086339270910241828492234013850730496.0000 | acc 79.77% +/- 1.32 step 18500 | loss 0.5066 | sloss 805638129748647807254721630594128144628200331656860431232299554568127757018410461062150169280118784.0000 | acc 79.10% +/- 1.45 step 19000 | loss 0.5434 | sloss 132707244988259347976950645930190439456890128019301151107010018148853164185758627251638181025703408959488.0000 | acc 79.04% +/- 0.71 step 19500 | loss 0.5604 | sloss 67292779817348793341164876982079855644322202012905413758540678056064225686564206160436726637686272433575690240.0000 | acc 79.78% +/- 1.24 step 20000 | loss 0.4841 | sloss 103547975802604524945139506381020614893573561035082067514121697752087316910937705117244521227357949931598627647520768.0000 | acc 79.76% +/- 0.55
loss_data = np.vstack(loss_hist)
acc_data = np.vstack(acc_hist)
loss_data_cont = np.vstack(loss_hist_cont)
acc_data_cont = np.vstack(acc_hist_cont)
f = plt.figure(figsize=[12,4])
plt.subplot(1,2,1) ; plt.title("Train loss") ; plt.xlabel("steps")
plt.plot(range(len(loss_hist)), loss_data, 'orange', label='synthetic')
plt.plot(range(len(loss_hist_cont)), loss_data_cont, 'blue', label='control')
plt.ylim([0,3])
plt.legend()
plt.subplot(1,2,2) ; plt.title("Test accuracy") ; plt.xlabel("steps")
plt.errorbar(range(len(loss_hist))[::500], acc_data[:,0], yerr=acc_data[:,1], errorevery=1, label='synthetic')
plt.errorbar(range(len(loss_hist_cont))[::500], acc_data_cont[:,0], yerr=acc_data_cont[:,1], errorevery=1, label='control')
plt.ylim([25,100])
plt.legend()
plt.show() ; f.savefig('./static/synthetic.png', bbox_inches='tight')
X, y = mnist.next()
y_hat, saved = forward(X, model)
predictions = y_hat.argmax(axis=1)
rows = 2
cols = 5
side = 2
f = plt.figure(figsize=[cols*side,rows*side])
for r in range(rows):
for c in range(cols):
img_ix = r*cols + c
plt.subplot(rows, cols, img_ix+1)
plt.title("pred: {}".format(predictions[img_ix]))
plt.imshow(X[img_ix].reshape(28,28), cmap='gray')
f.axes[img_ix].get_xaxis().set_visible(False)
f.axes[img_ix].get_yaxis().set_visible(False)
plt.show()