Adapted from the pymc coal mining disaster example
import numpy as np
import matplotlib.pyplot as plt
from pymc3 import Model, Normal, HalfNormal,sampling
from pymc3 import find_MAP
from pymc3 import NUTS, sample
from scipy import optimize
from pymc3 import traceplot, summary
%matplotlib inline
/Users/kersten/anaconda/envs/pymc3/lib/python2.7/site-packages/matplotlib/font_manager.py:273: UserWarning: Matplotlib is building the font cache using fc-list. This may take a moment. warnings.warn('Matplotlib is building the font cache using fc-list. This may take a moment.')
from pymc3 import Exponential, StudentT, Deterministic
from pymc3.math import exp
from pymc3.distributions.timeseries import GaussianRandomWalk
# Synthetic spike-rate series over 1000 steps (x runs from 1 to 1000).
x = np.linspace(1, 1000, num=1000)

# Noise-free rate: 35 while x < 453, then 39 from x = 453 onwards.
spikes = np.where(x < 453, 35.0, 39.0)

# Integer-valued noise: Gaussian draws scaled by 4 and rounded up.
noise = np.ceil(np.random.randn(x.size) * 4.0)
spikes_data = spikes + noise

# Flag the final observation with the sentinel -1 and mask it, so that it is
# carried through as a masked entry rather than a real count.
spikes_data[-1] = -1
rate_data = np.ma.masked_values(spikes_data, value=-1)

# Zero-based time index, one entry per observation.
time = np.arange(x.size)
# Scatter the simulated rates against the time index using small circle
# markers ('o' format string, markersize=1), then label both axes.
plt.plot(time, rate_data, 'o', markersize=1);
plt.ylabel("Rate")
plt.xlabel("Time")
<matplotlib.text.Text at 0x110c6b550>
# Sanity check: the masked series should still hold one entry per time step.
rate_data.shape
(1000,)
from pymc3 import DiscreteUniform, Poisson
from pymc3.math import switch
# Switch-point model for the spike rate: a Poisson likelihood whose rate is
# `early_rate` up to an unknown switchpoint and `late_rate` afterwards.
with Model() as rate_model:
    # Unknown time index at which the rate changes: uniform over the observed
    # time range, starting the sampler from index 500.
    switchpoint = DiscreteUniform('switchpoint', lower=time.min(), upper=time.max(), testval=500)

    # Priors for pre- and post-switch rates (number of spikes); both are
    # Exponential with rate parameter 1/25.
    early_rate = Exponential('early_rate', 1./25)
    late_rate = Exponential('late_rate', 1./25)

    # Allocate the appropriate Poisson rate to each time step: early_rate
    # where time <= switchpoint, late_rate otherwise.
    rate = switch(switchpoint >= time, early_rate, late_rate)

    # Likelihood over the masked observations.
    # NOTE: this rebinds `spikes`, which previously named the noise-free NumPy
    # array; later code uses `spikes.missing_values` on this model variable.
    spikes = Poisson('spikes', rate, observed=rate_data)
Applied log-transform to early_rate and added transformed early_rate_log_ to model. Applied log-transform to late_rate and added transformed late_rate_log_ to model.
import theano.tensor as T
$f(x \mid \lambda) = \lambda \exp\left\{ -\lambda x \right\}$
# Rebuild the piecewise rate expression from the model's random variables.
rate = switch(switchpoint >= time, early_rate, late_rate)

# Wrap each model variable in a Theano Print op labelled with its own name,
# in the order switchpoint, early_rate, late_rate.
switchpoint_print, early_rate_print, late_rate_print = (
    T.printing.Print(label)(variable)
    for label, variable in (
        ('switchpoint', switchpoint),
        ('early_rate', early_rate),
        ('late_rate', late_rate),
    )
)
switchpoint __str__ = 500 early_rate __str__ = 17.3286789401 late_rate __str__ = 17.3286789401
from pymc3 import Metropolis
# Sample from the posterior of rate_model, assigning a separate step method to
# the continuous and the discrete variables.
with rate_model:
    # NUTS for the two continuous rate parameters.
    step1 = NUTS([early_rate, late_rate])

    # Use Metropolis for switchpoint, and missing values since it accommodates discrete variables
    step2 = Metropolis([switchpoint, spikes.missing_values[0]])

    # Draw 1000 samples using both step methods together.
    trace = sample(1000, step=[step1, step2])
100%|██████████| 1000/1000 [00:01<00:00, 862.11it/s]
# Plot the sampled traces for the model variables; the trailing semicolon
# keeps the notebook from echoing the call's return value.
traceplot(trace);
The true switchpoint is at 453. The true early mean is 35, and the late mean is 39.
# Posterior predictive sampling from the fitted trace under rate_model,
# requesting samples=1000 with size=10.
ppc = sampling.sample_ppc(
    trace,
    samples=1000,
    model=rate_model,
    size=10,
)