Notes for Why does deep and cheap learning work so well? (ArXiv:1608.08225v1/cond-mat.dis-nn) by Lin and Tegmark.
Let's implement a simple multiplication network as per figure 2 in the paper. Obviously this could be done much more efficiently with any of the deep-learning-in-a-box packages that exist today. My interest is strictly to gain some intuition about this result, so I'm doing it all by hand here.
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
plt.style.use('ggplot')
plt.rcParams['figure.figsize'] = 10, 6
plt.rcParams['axes.facecolor'] = "0.92"
def σ(y):
    "Logistic sigmoid"
    return 1/(1+np.exp(-y))
y = np.linspace(-10, 10)
plt.plot(y, σ(y));
We need σ″(0). Analytically I'm getting
$$\sigma''(y) = \frac{e^{-y}\left(e^{-y}-1\right)}{\left(e^{-y}+1\right)^{3}}$$
I'm rusty, so let's double check my algebra with SymPy:
import sympy as S
S.init_printing(use_latex=True)
y = S.symbols('y')
s = 1/(1+S.exp(-y))
s1 = S.diff(s, y)
s1
s2 = S.simplify(S.diff(s1, y))
s2
These appear to be different, but that's because SymPy simplified its result in terms of $e^{y}$ while I wrote mine in terms of $e^{-y}$. Since I'm lazy, let's have SymPy check that the two forms are identical (the algebra is trivial, but it's late and I'm tired):
ey = S.exp(-y)
s2f = ey*(ey-1)/(ey+1)**3
s2f
S.simplify(s2 - s2f)
Indeed.
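As an extra sanity check (my own, not from the paper), we can also compare the analytic σ″ against a central finite difference at a few points; np.allclose should come back True:
# Numerical spot-check: analytic σ'' vs. a central finite difference (my own check)
yy = np.array([-2.0, -0.5, 0.5, 1.0, 3.0])
h = 1e-4
fd = (σ(yy + h) - 2*σ(yy) + σ(yy - h)) / h**2  # finite-difference estimate of σ''
ee = np.exp(-yy)
np.allclose(fd, ee*(ee - 1)/(ee + 1)**3, atol=1e-6)  # the closed form above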
As per the paper, μ must be defined as
$$\mu = \frac{\lambda^{-2}}{8\,\sigma''(0)}$$
Note that in order to construct the full network, we need to shift the origin of σ by 1 so that the second derivative at the (shifted) origin is nonzero and μ stays finite: for the logistic sigmoid, σ″(0) = 0, since the numerator of σ″(y) vanishes at y = 0. We can do this by adding a bias term b = [1, 1, 1, 1], as described in the paper. This means we need to evaluate σ″(1) instead:
S.init_printing(use_latex=False) # Turn it off again so it doesn't slow everything down.
# Let's evaluate the 2nd derivative at the origin (shifted to y=1), numerically
# (so we can use it below in the network). Since the point of evaluation of σ'' must be consistent
# with the definition of the vector b below, let's store it in a variable we also use to construct b
σ_origin = 1
s2_ori = float(s2.subs(y, σ_origin).n())
s2_ori
-0.09085774767294841
Let's define the generic form of an affine layer. To support proper composition in a multi-layer network, we need to express the affine layer as a function:
def A(W, b):
    "Affine layer"
    return lambda y: W @ y + b
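A quick usage check with made-up numbers, just to confirm the returned closure behaves the way we want:
# Tiny usage example (made-up values): y ↦ I·y + [1, 2]
A_test = A(np.eye(2), np.array([1.0, 2.0]))
A_test(np.array([3.0, 4.0]))  # expect array([4., 6.])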
Now, let's construct the full network in Fig. 2 (left side). It should converge to the multiplication operator as λ→0.
Warning: I am having to redefine μ → 2μ in order to get the right numerical results. It may be an error in the paper's algebra or in my implementation, but I haven't had time to track it down yet. Feedback/hints welcome.
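For what it's worth, here is a quick Taylor-expansion check of that factor, done against the specific W1, b and w2 defined just below (so treat it as a consistency check on this implementation, not a verdict on the paper's conventions). Writing b = σ_origin = 1 for the common bias of each hidden unit, the even part of the sigmoid around the shifted origin is
$$\sigma(b+x) + \sigma(b-x) = 2\sigma(b) + \sigma''(b)\,x^{2} + O(x^{4}),$$
so the particular combination the network forms is
$$\sigma\big(b+\lambda(u+v)\big) + \sigma\big(b-\lambda(u+v)\big) - \sigma\big(b+\lambda(u-v)\big) - \sigma\big(b-\lambda(u-v)\big) = \sigma''(b)\,\lambda^{2}\big[(u+v)^{2} - (u-v)^{2}\big] + O(\lambda^{4}) = 4\,\sigma''(b)\,\lambda^{2}\,uv + O(\lambda^{4}).$$
Undoing that prefactor requires multiplying by $\lambda^{-2}/\big(4\sigma''(b)\big)$, which is exactly 2μ with μ as defined above, so the extra factor of 2 at least makes the Taylor expansion come out right.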
# The connection matrices for the two affine layers. These are just constants
# that only need to be defined once
w1 = np.array([1.0, 1, -1, -1, 1, -1, -1, 1]).reshape(4,2)
b = σ_origin * np.array([1.0, 1, 1, 1]) # The bias shift to avoid evaluating σ''(0)
w2 = np.array([1.0, 1, -1, -1])
λ = 0.00001 # as it -> 0, the multiply() function improves in accuracy
μ = 1.0/(8*s2_ori*λ**2) # from formula in fig. 2
# The actual matrices carry λ and μ:
W1 = λ*w1
W2 = 2*μ*w2 # This factor of 2 is not in the paper. Error in my algebra?
# Now we build the affine layers as functions
A1 = A(W1, b)
A2 = A(W2, 0)
# With these in place, we can then build the 3-layer network that approximates
# the multiplication operator
def multiply(u, v):
    "Multiply two numbers with a neural network."
    y = np.array([u, v], dtype=float)
    return A2(σ(A1(y)))
# Let's verify it with two numbers
u, v = 376, 432
uvn = multiply(u, v)
uv = u*v
err = abs(uv - uvn)
print("λ :", λ)
print("Network u*v:", uvn)
print("Exact u*v :", uv)
print("Abs. Error : %.2g" % err)
print("Rel. Error : %.2g" % (err/uv))
λ : 1e-05
Network u*v: 162430.792953
Exact u*v : 162432
Abs. Error : 1.2
Rel. Error : 7.4e-06
Let's have a quick look at how the approximation converges. From eq. (8) in the paper, we expect the error to be $O(\lambda^2(u^2+v^2))$. Note that the equation in the paper states the error without the extra $\lambda^2$ factor, but that's because its analysis assumes $|u| \ll 1, |v| \ll 1$, whereas our implementation relaxes this restriction by rescaling u and v by λ in the first affine layer.
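As a quick back-of-envelope check of that scaling against the single evaluation above (u = 376, v = 432, λ = 1e-5), the estimate comes out at the same order of magnitude as the observed relative error, modulo an O(1) constant:
# Expected relative-error scale for the example above: λ²(u² + v²)
λ**2 * (u**2 + v**2)  # ≈ 3.3e-05, vs. the observed 7.4e-06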
In order to conveniently scan over values of λ, u and v, it will be helpful to encapsulate the above construction into a callable object that precomputes all relevant quantities at construction time (for each λ) and then can be quickly called:
class NNMultiply:
    def __init__(self, λ=1e-5):
        self.λ = λ
        μ = 1.0/(8*s2_ori*λ**2)
        self.W1 = λ*w1
        self.W2 = 2*μ*w2  # This factor of 2 is not in the paper. Error in my algebra?
        self.b = σ_origin * np.array([1.0, 1, 1, 1])  # The bias shift to avoid evaluating σ''(0)

    def __call__(self, u, v):
        "Multiply two numbers with a neural network."
        y = np.array([u, v], dtype=float)
        # Since we'll be calling this a lot, let's make a small optimization and "unroll"
        # our network to avoid a few unnecessary function calls
        return self.W2 @ σ(self.W1 @ y + self.b)
# Let's verify it with the same values as above, as a sanity check
u, v = 376, 432
mult = NNMultiply(λ)
uvn = mult(u, v)
uv = u*v
err = abs(uv - uvn)
print("λ :", mult.λ)
print("Network u*v:", uvn)
print("Exact u*v :", uv)
print("Abs. Error : %.2g" % err)
print("Rel. Error : %.2g" % (err/uv))
λ : 1e-05
Network u*v: 162430.792953
Exact u*v : 162432
Abs. Error : 1.2
Rel. Error : 7.4e-06
Now, we can build an error plot at various values of λ. Since the function we're approximating is simply uv, we can keep u constant and only scan over v for each λ, as long as we cover a range that goes from u≪v to u≫v.
In the figure below, next to each line for a given λ (and in the same color) is a dashed line that plots $\lambda^2(u^2+v^2)$, which should be (modulo a constant) a good estimate of the observed error.
u = 10
lambdas = np.logspace(-1, -7, 7)
vv = np.linspace(0.1, 100, 50)
fig, ax = plt.subplots()
for λ in lambdas:
    mult = NNMultiply(λ)
    err = []
    for v in vv:
        uvn = mult(u, v)
        uv = u*v
        err.append(abs(uv - uvn)/uv)
    l, = ax.semilogy(vv, err, label=r"$\lambda =$ %.2g" % λ)
    ax.semilogy(vv, (λ**2)*(u**2+vv**2), '--', color=l.get_color())
ax.set_xlabel('v')
ax.set_ylabel('Rel. error')
ax.set_title("u fixed at %g" % u)
ax.legend();
As we see, once λ drops below $10^{-5}$ we start hitting some numerical issues, and below $10^{-6}$ the error is not only worse than the analytical estimate, it actually starts getting worse as λ gets smaller. This is because $\mu \sim 1/\lambda^2$, and in double precision we don't have enough digits to go further.
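A rough way to see where the digits go (my own back-of-envelope with representative values, not something from the paper): the four hidden-unit outputs are all ≈ σ(1), so the information about uv lives in an alternating sum of size ~ |σ″(1)|·4λ²uv, which has to compete with the ~1e-16 rounding error carried by each O(1) term:
# Back-of-envelope for the smallest λ in the scan (u=10, v=100 representative of the plot range)
λ_test, u_test, v_test = 1e-7, 10.0, 100.0
signal = abs(s2_ori) * 4 * λ_test**2 * u_test * v_test  # surviving O(λ²) part of the σ sum
noise = np.finfo(float).eps * σ(σ_origin)                # rounding error of one O(1) σ term
noise / signal  # ≈ 4e-5, swamping the ~1e-10 analytic estimate λ²(u²+v²) at this λ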