import random
import itertools
import jax
import jax.numpy as np
# Current convention is to import original numpy as "onp"
import numpy as onp
# Sigmoid nonlinearity
def sigmoid(x):
    return 1 / (1 + np.exp(-x))
# Computes our network's output
def net(params, x):
    w1, b1, w2, b2 = params
    hidden = np.tanh(np.dot(w1, x) + b1)
    return sigmoid(np.dot(w2, hidden) + b2)
# Cross-entropy loss
def loss(params, x, y):
    out = net(params, x)
    cross_entropy = -y * np.log(out) - (1 - y) * np.log(1 - out)
    return cross_entropy
# Utility function for testing whether the net produces the correct
# output for all possible inputs
def test_all_inputs(inputs, params):
    predictions = [int(net(params, inp) > 0.5) for inp in inputs]
    for inp, out in zip(inputs, predictions):
        print(inp, '->', out)
    return (predictions == [onp.bitwise_xor(*inp) for inp in inputs])
def initial_params():
    return [
        onp.random.randn(3, 2),  # w1
        onp.random.randn(3),     # b1
        onp.random.randn(3),     # w2
        onp.random.randn(),      # b2
    ]
loss_grad = jax.grad(loss)
# Stochastic gradient descent learning rate
learning_rate = 1.
# All possible inputs
inputs = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
# Initialize parameters randomly
params = initial_params()
for n in itertools.count():
    # Grab a single random input
    x = inputs[onp.random.choice(inputs.shape[0])]
    # Compute the target output
    y = onp.bitwise_xor(*x)
    # Get the gradient of the loss for this input/output pair
    grads = loss_grad(params, x, y)
    # Update parameters via gradient descent
    params = [param - learning_rate * grad
              for param, grad in zip(params, grads)]
    # Every 100 iterations, check whether we've solved XOR
    if not n % 100:
        print('Iteration {}'.format(n))
        if test_all_inputs(inputs, params):
            break
Iteration 0
[0 0] -> 0
[0 1] -> 0
[1 0] -> 0
[1 1] -> 0
Iteration 100
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 1
Iteration 200
[0 0] -> 0
[0 1] -> 1
[1 0] -> 1
[1 1] -> 0
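The gradient step above runs op-by-op on every iteration; wrapping it in jax.jit compiles it once so repeated calls with same-shaped inputs reuse the compiled computation. A minimal optional sketch (the name loss_grad_jit is ours, not part of the example above):

# Optional sketch: jit-compile the gradient of the loss defined above.
loss_grad_jit = jax.jit(jax.grad(loss))
# Drop-in replacement for loss_grad inside the training loop:
grads = loss_grad_jit(params, np.array([0, 1]), 1)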
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)
[-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442 -0.67135346 -0.5908641 0.73168886 0.5673026 ]
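JAX's random functions are stateless: the same key always produces the same numbers, so fresh draws require splitting the key rather than relying on hidden global state. A minimal sketch continuing from the key above (variable names here are illustrative):

# Split the key so each new draw uses an independent subkey.
key, subkey = random.split(key)
y = random.normal(subkey, (10,))  # different values from x above
print(y)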
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU
1 loop, best of 3: 931 ms per loop
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
1 loop, best of 3: 938 ms per loop
from jax import device_put
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
1 loop, best of 3: 903 ms per loop
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
1 loop, best of 3: 452 ms per loop
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU
The slowest run took 227.61 times longer than the fastest. This could mean that an intermediate result is being cached. 1 loop, best of 3: 10.1 ms per loop
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
10 loops, best of 3: 34.9 ms per loop
from jax import device_put
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
10 loops, best of 3: 20.3 ms per loop
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
1 loop, best of 3: 324 ms per loop
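The gap between the two sets of timings depends on which backend the JAX arrays actually live on and whether each call has to copy the NumPy array to the device first. A quick way to check the active backend (not part of the original example):

# Report the devices JAX has picked up and the default backend.
import jax
print(jax.devices())           # e.g. [GpuDevice(id=0)] or [CpuDevice(id=0)]
print(jax.default_backend())   # 'gpu', 'tpu', or 'cpu'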
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
The slowest run took 159.08 times longer than the fastest. This could mean that an intermediate result is being cached. 1 loop, best of 3: 2.36 ms per loop
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
The slowest run took 433.81 times longer than the fastest. This could mean that an intermediate result is being cached. 1000 loops, best of 3: 270 µs per loop
x
DeviceArray([ 1.99376 , 0.20781846, -0.34406224, ..., 0.03467206, 0.7103182 , 0.1965587 ], dtype=float32)
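vmap is imported above but never used. A minimal sketch of automatic vectorization, mapping a per-example function over a batch without a Python loop (mat, apply_matrix, and the shapes are illustrative, not from the original):

# vmap vectorizes apply_matrix over the leading axis of its input.
mat = random.normal(key, (150, 100))
batch = random.normal(key, (10, 100))

def apply_matrix(v):
    return jnp.dot(mat, v)

batched_apply = vmap(apply_matrix)   # maps over axis 0 by default
print(batched_apply(batch).shape)    # (10, 150)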