This is a two-layer neural network trained with synthetic gradients, as in the paper Decoupled Neural Interfaces using Synthetic Gradients (Jaderberg et al., 2016). While the code is not an exact implementation, it captures the major ideas of the paper in a concrete way.
The basic idea behind synthetic gradients is to predict the gradients of a neural network's parameters using a second, smaller network. For a really good conceptual introduction, read the DeepMind blog post about the paper.
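To make the core idea concrete before the full notebook, here is a minimal sketch of a single synthetic-gradient update for one weight matrix. The names (h, W, M, g_hat, g) and the learning rates are purely illustrative; the actual implementation used in this notebook follows below.
# minimal sketch of one synthetic-gradient update (illustrative, not the notebook's code)
import numpy as np
h = np.random.randn(10, 100)                         # activations of some layer for a batch of 10
W = np.random.randn(100, 10)                         # downstream weights we want to update
M = np.random.randn(100, W.size) * 0.01              # tiny linear model mapping h -> predicted gradient of W
g_hat = np.dot(h, M)                                 # synthetic gradient, one flattened row per example
W -= 2e-2 * np.mean(g_hat, axis=0).reshape(W.shape)  # update W immediately, without waiting for backprop
# ...later, whenever a true gradient g becomes available, train the predictor by regression:
g = np.random.randn(*W.shape)                        # stand-in for the true gradient
M -= 1e-4 * np.dot(h.T, g_hat - g.ravel())           # gradient of the squared error between g_hat and g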
import numpy as np
import copy
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data # just use tensorflow's mnist api
mnist = input_data.read_data_sets('MNIST_data', one_hot=False)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
global_step = 0
batch_size = 10
lr = 2e-2 # model learning rate
slr = 1e-4 # synthetic learning rate
h1_size = 100 # first hidden layer
h2_size = 10 # second layer (also the 10-way output layer)
D = 28*28 # dimensionality
synth_step = 10 # ratio of model updates to synthetic gradient updates
def make_model():
model = {}
# first layer
model['W1'] = np.random.randn(D,h1_size) / np.sqrt(h1_size) # Xavier initialization
model['b1'] = np.random.randn(1,h1_size) / np.sqrt(h1_size) # see http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
#second layer
model['W2'] = np.random.randn(h1_size,h2_size) / np.sqrt(h2_size)
model['b2'] = np.random.randn(1,h2_size) / np.sqrt(h2_size)
return model
synth = make_model() ; stale = make_model() ; control = make_model()
# model for predicting synthetic gradients
smodel = { k : np.random.randn(v.shape[1], np.prod(v.shape)) * 0.1 / np.sqrt(np.prod(v.shape)) \
for k,v in synth.iteritems() }
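Each entry of smodel is a linear map from a layer's activations to a flattened gradient for one parameter tensor: W1 and b1 are predicted from the first hidden layer, W2 and b2 from the second. A quick sanity check of the shapes, assuming the hyperparameters defined above:
# shape sanity check for the synthetic-gradient model
for k, v in synth.iteritems():
    assert smodel[k].shape == (v.shape[1], np.prod(v.shape))
# e.g. smodel['W1'] has shape (h1_size, D*h1_size) = (100, 78400): it maps h1 to a flattened gradient of W1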
# model functions
def xW_plus_b(x, W, b):
return np.dot(x,W) + b # in some cases you can even drop the bias b
def softmax(x):
maxes = np.amax(x, axis=1, keepdims=True)
e = np.exp(x - maxes) # improves numerics
dist = e / np.sum(e, axis=1, keepdims=True)
return dist
def relu(x):
x[x<0] = 0
return x
# derivatives of model functions
# see https://nbviewer.jupyter.org/github/greydanus/np_nets/blob/master/mnist_nn.ipynb for derivations
def dsoftmax(h, y, batch_size):
h[range(batch_size),y] -= 1
return h/y.shape[0] # divide by batch size
def drelu(dz, h):
dz[h <= 0] = 0 # backprop relu
return dz
def dxW_plus_b(dh, model):
return np.dot(dh, model['W2'].T)
def forward(X, model):
# evaluate class scores, [N x K]
hs = [] # we'll need the h's for computing gradients
h1 = relu(np.dot(X, model['W1']) + model['b1']) ; hs.append(h1)
h2 = relu(np.dot(h1, model['W2']) + model['b2']); hs.append(h2)
probs = softmax(h2)
return probs, hs
def backward(y, probs, hs, model):
grads = { k : np.zeros_like(v) for k,v in model.iteritems() }
dh2 = dsoftmax(probs, y, batch_size)
# second hidden layer
grads['W2'] = np.dot(hs[0].T, dh2)
grads['b2'] = np.sum(dh2, axis=0, keepdims=True)
# first hidden layer
dh1 = dxW_plus_b(dh2, model)
dh1 = drelu(dh1, hs[0]) # backprop through relu
    grads['W1'] = np.dot(X.T, dh1) # note: X (the current minibatch) is read from the enclosing scope
grads['b1'] = np.sum(dh1, axis=0, keepdims=True)
return grads
# forward propagate synthetic grad model
def sforward(hs, smodel, model):
synthetic_grads = { k : np.zeros_like(v) for k,v in smodel.iteritems() }
synthetic_grads['W1'] = np.dot(hs[0], smodel['W1'])
synthetic_grads['b1'] = np.dot(hs[0], smodel['b1'])
synthetic_grads['W2'] = np.dot(hs[1], smodel['W2'])
synthetic_grads['b2'] = np.dot(hs[1], smodel['b2'])
return synthetic_grads
# backward propagate synthetic grad model
def sbackward(hs, ds, smodel):
sgrads = { k : np.zeros_like(v) for k,v in smodel.iteritems() }
sgrads['W2'] = np.dot(hs[1].T, ds['W2'])
sgrads['b2'] = np.dot(hs[1].T, ds['b2'])
sgrads['W1'] = np.dot(hs[0].T, ds['W1'])
sgrads['b1'] = np.dot(hs[0].T, ds['b1'])
return sgrads
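# note: together, sforward and sbackward implement a simple regression. sforward predicts a flattened
# gradient for every parameter tensor from a layer's activations; sbackward is the gradient of the
# squared error between those predictions and the true gradients (the error term ds is formed in the
# training loop below) with respect to smodel's weights.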
# evaluate test set accuracy
def eval_model(model):
X = mnist.test.images
y = mnist.test.labels
hidden_layer = np.maximum(0, np.dot(X, model['W1']) + model['b1'])
scores = np.dot(hidden_layer, model['W2']) + model['b2']
predicted_class = np.argmax(scores, axis=1)
return(np.mean(predicted_class == y))
synth_history = []
smoothing_factor = 0.95
for i in xrange(global_step, 7500):
X, y = mnist.train.next_batch(batch_size)
probs, hs = forward(X, synth)
synthetic_grads = sforward(hs, smodel, synth) # compute synthetic gradients
# synthetic gradient model updates
if i % synth_step == 0:
# compute the loss
y_logprobs = -np.log(probs[range(batch_size),y])
loss = np.sum(y_logprobs)/batch_size
        if i == 0: smooth_loss = loss
grads = backward(y, probs, hs, synth) # data gradients
# compute the synthetic gradient loss
ds = {k : - v.ravel() + synthetic_grads[k] for (k, v) in grads.iteritems()}
squared_errors = [np.sum(slr*v*v) for v in ds.itervalues()]
sloss = np.sum(squared_errors)/batch_size
sgrads = sbackward(hs, ds, smodel)
smodel = {k : smodel[k] - slr*sgrads[k] for (k,v) in sgrads.iteritems()} # update smodel parameters
# boring book-keeping
smooth_loss = smoothing_factor*smooth_loss + (1-smoothing_factor)*loss
synth_history.append((i,smooth_loss))
if (i+1) % 1000 == 0:
print "iteration {}: test accuracy {:3f}".format(i, eval_model(synth))
if (i) % 250 == 0:
print "\titeration {}: smooth_loss {:3f}, synth_loss {:3f}".format(i, smooth_loss, sloss)
# estimate gradients using synthetic gradient model
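    # synthetic_grads[k] holds one flattened prediction per example, so average over the batch and reshape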
est_grad = {k : np.reshape(np.sum(v, axis=0), grads[k].shape)/batch_size for k,v in synthetic_grads.iteritems()}
synth = {k : synth[k] - lr*v for (k,v) in est_grad.iteritems()} # update model using estimated gradient
global_step += 1
iteration 0: smooth_loss 3.815237, synth_loss 0.009213
iteration 250: smooth_loss 2.493766, synth_loss 0.002045
iteration 500: smooth_loss 1.379297, synth_loss 0.000673
iteration 750: smooth_loss 1.012663, synth_loss 0.001725
iteration 999: test accuracy 0.738500
iteration 1000: smooth_loss 1.057044, synth_loss 0.000423
iteration 1250: smooth_loss 0.777043, synth_loss 0.001821
iteration 1500: smooth_loss 0.607534, synth_loss 0.000836
iteration 1750: smooth_loss 0.615273, synth_loss 0.001889
iteration 1999: test accuracy 0.808900
iteration 2000: smooth_loss 0.584410, synth_loss 0.001935
iteration 2250: smooth_loss 0.579823, synth_loss 0.000686
iteration 2500: smooth_loss 0.517544, synth_loss 0.000147
iteration 2750: smooth_loss 0.591252, synth_loss 0.001182
iteration 2999: test accuracy 0.872600
iteration 3000: smooth_loss 0.504129, synth_loss 0.000169
iteration 3250: smooth_loss 0.560493, synth_loss 0.001131
iteration 3500: smooth_loss 0.454970, synth_loss 0.000051
iteration 3750: smooth_loss 0.431393, synth_loss 0.001128
iteration 3999: test accuracy 0.894100
iteration 4000: smooth_loss 0.441746, synth_loss 0.000462
iteration 4250: smooth_loss 0.417226, synth_loss 0.000246
iteration 4500: smooth_loss 0.435450, synth_loss 0.000374
iteration 4750: smooth_loss 0.369929, synth_loss 0.000181
iteration 4999: test accuracy 0.907100
iteration 5000: smooth_loss 0.327477, synth_loss 0.000486
iteration 5250: smooth_loss 0.254708, synth_loss 0.000769
iteration 5500: smooth_loss 0.282602, synth_loss 0.000079
iteration 5750: smooth_loss 0.285151, synth_loss 0.000702
iteration 5999: test accuracy 0.900400
iteration 6000: smooth_loss 0.301963, synth_loss 0.000240
iteration 6250: smooth_loss 0.348865, synth_loss 0.000360
iteration 6500: smooth_loss 0.331123, synth_loss 0.000797
iteration 6750: smooth_loss 0.298193, synth_loss 0.000218
iteration 6999: test accuracy 0.909100
iteration 7000: smooth_loss 0.308105, synth_loss 0.000059
iteration 7250: smooth_loss 0.317047, synth_loss 0.000645
stale_history = []
global_step = 0
smoothing_factor = 0.95
for i in xrange(global_step, 7500):
X, y = mnist.train.next_batch(batch_size)
probs, hs = forward(X, stale)
# calculate gradients
if i % synth_step == 0:
# compute the loss
y_logprobs = -np.log(probs[range(batch_size),y])
loss = np.sum(y_logprobs)/batch_size
        if i == 0: smooth_loss = loss
grads = backward(y, probs, hs, stale) # data gradients
# boring book-keeping
smooth_loss = smoothing_factor*smooth_loss + (1-smoothing_factor)*loss
stale_history.append((i,smooth_loss))
if (i+1) % 1000 == 0:
print "iteration {}: test accuracy {:3f}".format(i, eval_model(stale))
if (i) % 250 == 0:
print "\titeration {}: smooth_loss {:3f}, stale_loss {:3f}".format(i, smooth_loss, sloss)
# update model using stale gradients
stale = {k : stale[k] - lr*v for (k,v) in grads.iteritems()}
global_step += 1
iteration 0: smooth_loss 3.030677, stale_loss 0.000600
iteration 250: smooth_loss 2.297427, stale_loss 0.000600
iteration 500: smooth_loss 1.606829, stale_loss 0.000600
iteration 750: smooth_loss 1.263941, stale_loss 0.000600
iteration 999: test accuracy 0.728900
iteration 1000: smooth_loss 1.055426, stale_loss 0.000600
iteration 1250: smooth_loss 0.813792, stale_loss 0.000600
iteration 1500: smooth_loss 0.761992, stale_loss 0.000600
iteration 1750: smooth_loss 0.631247, stale_loss 0.000600
iteration 1999: test accuracy 0.808100
iteration 2000: smooth_loss 0.619193, stale_loss 0.000600
iteration 2250: smooth_loss 0.689113, stale_loss 0.000600
iteration 2500: smooth_loss 0.654125, stale_loss 0.000600
iteration 2750: smooth_loss 0.602515, stale_loss 0.000600
iteration 2999: test accuracy 0.844400
iteration 3000: smooth_loss 0.567582, stale_loss 0.000600
iteration 3250: smooth_loss 0.572366, stale_loss 0.000600
iteration 3500: smooth_loss 0.524966, stale_loss 0.000600
iteration 3750: smooth_loss 0.444738, stale_loss 0.000600
iteration 3999: test accuracy 0.856900
iteration 4000: smooth_loss 0.429342, stale_loss 0.000600
iteration 4250: smooth_loss 0.325885, stale_loss 0.000600
iteration 4500: smooth_loss 0.444577, stale_loss 0.000600
iteration 4750: smooth_loss 0.451115, stale_loss 0.000600
iteration 4999: test accuracy 0.876700
iteration 5000: smooth_loss 0.461882, stale_loss 0.000600
iteration 5250: smooth_loss 0.443143, stale_loss 0.000600
iteration 5500: smooth_loss 0.383808, stale_loss 0.000600
iteration 5750: smooth_loss 0.401103, stale_loss 0.000600
iteration 5999: test accuracy 0.885000
iteration 6000: smooth_loss 0.417234, stale_loss 0.000600
iteration 6250: smooth_loss 0.368787, stale_loss 0.000600
iteration 6500: smooth_loss 0.358992, stale_loss 0.000600
iteration 6750: smooth_loss 0.432486, stale_loss 0.000600
iteration 6999: test accuracy 0.888500
iteration 7000: smooth_loss 0.458806, stale_loss 0.000600
iteration 7250: smooth_loss 0.403893, stale_loss 0.000600
# first, flatten all gradients and estimated gradients
g, s = np.zeros((0)), np.zeros((0))
for (k,v) in grads.iteritems():
g = np.concatenate((g,copy.deepcopy(v.ravel())))
s = np.concatenate((s,copy.deepcopy(est_grad[k].ravel())))
# now keep only those which are fairly large and see if their signs agree
eps = 1e-3
s[np.abs(s)<eps] = 0 ; g[np.abs(g)<eps] = 0 # ignore deviations of epsilon around 0
num_zeros = sum(s.ravel() * g.ravel() == 0)
num_matching_signs = sum(s.ravel() * g.ravel() > 0)
percentage = float(num_matching_signs) / (len(g) - num_zeros) * 100
print "{:2f}% same sign out of {} nonzero gradients".format(percentage, len(g) - num_zeros)
# plot a slice of the true vs. synthetic gradients for the W2 parameters
print "Plotting grads vs est_grads for W2 parameters"
plt.figure(figsize=(12,16))
plt.subplot(121) ; plt.title("Actual gradient W2")
plt.imshow(grads['W2'].T[:,:50], cmap=cm.jet)
plt.subplot(122) ; plt.title("Synthetic gradient W2")
plt.imshow(est_grad['W2'].T[:,:50], cmap=cm.jet)
plt.show()
45.071011% same sign out of 4788 nonzero gradients
Plotting grads vs est_grads for W2 parameters
plt.figure(figsize=(8,6))
plt.title("Smoothed loss")
plt.xlabel("training steps")
plt.ylabel("loss")
train_steps, synth_losses = zip(*synth_history[1000/synth_step:])
train_steps, stale_losses = zip(*stale_history[1000/synth_step:])
synth_line, = plt.plot(train_steps, synth_losses, label='Synthetic gradients')
stale_line, = plt.plot(train_steps, stale_losses, label='Stale gradients')
plt.legend(handles=[synth_line, stale_line])
plt.show()