Example Notebook for using Stochastic Gradient Fisher Scoring Step Method - It is based on this
The goal here is to learn from small noisy samples of a quadratic function and estimate the parameters of the function.
%matplotlib inline
import functools
import numpy as np
import theano.tensor as tt
import matplotlib.pyplot as plt
import pymc3 as pm
np.random.seed(42)
Consider a simple quadratic problem with unknown parameters.
def f(x, a, b, c):
    return a*x**2 + b*x + c
a, b, c = 1, 2, 3
min_ = np.array([-b/2/a])
Generate the quadratic data with noise.
batch_size = 10
total_size = batch_size*100
def xy_obs_generator(batch_size):
    while True:
        x_obs = np.random.uniform(-5, 5, size=(batch_size,)).astype('float32')
        result = np.asarray([x_obs, f(x_obs, a, b, c) + np.random.normal(size=x_obs.shape).astype('float32')])
        yield result
x_train = np.random.uniform(-5, 5, size=(total_size,)).astype('float32')
x_obs = pm.Minibatch(x_train, batch_size=batch_size)
# xy_obs = pm.generator(xy_obs_generator(batch_size))
y_train = f(x_train, a, b, c) + np.random.normal(size=x_train.shape).astype('float32')
y_obs = pm.Minibatch(y_train, batch_size=batch_size)
# Example observation
# obs = xy_obs.eval()
plt.plot(x_train, y_train, 'bo');
Our task is to estimate the quadratic function. We will test this by calculating the L2 loss between samples generated from the trained model and the observed output.
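The L2 loss mentioned here can be sketched in plain NumPy. The helper name `l2_loss` and the synthetic data below are illustrative assumptions, not part of this notebook or of PyMC3. The sketch uses the mean of the squared residuals and only shows how a candidate parameter vector could be scored against noisy observations.

```python
import numpy as np

def f(x, a, b, c):
    return a * x**2 + b * x + c

def l2_loss(params, x, y):
    """Mean squared residual between the quadratic given by params and y."""
    a, b, c = params
    residual = f(x, a, b, c) - y
    return float(np.mean(residual**2))

# Synthetic data with its own seed (assumption: same process as the notebook).
rng = np.random.default_rng(0)
x = rng.uniform(-5, 5, size=1000)
y = f(x, 1, 2, 3) + rng.normal(size=x.shape)

loss_true = l2_loss((1.0, 2.0, 3.0), x, y)   # roughly the noise variance
loss_prior = l2_loss((1.0, 1.0, 1.0), x, y)  # the prior mean fits worse
print(loss_true, loss_prior)
```

With unit-variance noise, a good estimate cannot push this loss much below 1, so values near 1 indicate that the parameters explain everything except the noise.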
burn_in = 50
draws = 1000
with pm.Model() as model:
    abc = pm.Normal('abc', sd=1, mu=1, shape=(3,))
    x = x_obs
    x2 = x**2
    o = tt.ones_like(x)
    X = tt.stack([x2, x, o]).T
    y = X.dot(abc)
    pm.Normal('y', mu=y, observed=y_obs, total_size=total_size)
    step_method = pm.SGFS(batch_size=batch_size, step_size=1.0, total_size=total_size)
    trace = pm.sample(draws=draws, step=step_method, init=None)
pm.traceplot(trace[burn_in:]);
/home/laoj/Documents/Github/pymc3/pymc3/step_methods/sgmcmc.py:107: UserWarning: Warning: Stochastic Gradient based sampling methods are experimental step methods and not yet recommended for use in PyMC3!
  warnings.warn(EXPERIMENTAL_WARNING)
100%|██████████| 1500/1500 [00:00<00:00, 2203.90it/s]
We do not recover the exact function because our observations were noisy, but the posterior means below are close to the true values a, b, c = 1, 2, 3.
trace['abc'].mean(axis=0)
array([ 0.99583151, 1.9845134 , 3.13419727])
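As a rough sanity check, the posterior means printed above can be compared with the true parameters and with an ordinary least-squares fit. This sketch is not part of the original notebook. It assumes the same generating process but draws fresh data with its own seed, so the numbers will not match the notebook's data exactly, and it uses `np.polyfit` as a baseline.

```python
import numpy as np

true_abc = np.array([1.0, 2.0, 3.0])
# Posterior means copied from the output above.
sgfs_abc = np.array([0.99583151, 1.9845134, 3.13419727])

# Fresh synthetic data (assumption: same process, different seed).
rng = np.random.default_rng(0)
x = rng.uniform(-5, 5, size=1000)
y = np.polyval(true_abc, x) + rng.normal(size=x.shape)

# Least-squares baseline; polyfit returns coefficients highest degree first,
# which matches the (a, b, c) ordering used here.
ols_abc = np.polyfit(x, y, deg=2)

def mse(params):
    """Mean squared residual of the quadratic with coefficients params."""
    return float(np.mean((np.polyval(params, x) - y) ** 2))

print("abs error of SGFS means:", np.abs(sgfs_abc - true_abc))
print("MSE, SGFS means:", mse(sgfs_abc))
print("MSE, OLS fit   :", mse(ols_abc))
```

The least-squares fit minimizes the mean squared error on its own data by construction, so it gives a lower bound against which the SGFS estimate can be judged.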