Probabilistic Programming 4: Latent Variables & Dynamic Models

Goal

  • Understand how to estimate parameters for latent variable models.
  • Understand how to estimate states in dynamical models.

Materials

Note that none of the material below is new. The point of the Probabilistic Programming sessions is to solve practical problems so that the concepts from Bert's lectures become less abstract.

In [1]:
using Pkg
Pkg.activate("./workspace")
Pkg.instantiate()
IJulia.clear_output();
In [2]:
using JLD
using Statistics
using StatsBase
using LinearAlgebra
using ProgressMeter
using LaTeXStrings
using ForneyLab
using Plots
pyplot();

using LinearAlgebra: I
using ForneyLab: unsafeMean
include("../scripts/clusters.jl");
include("../scripts/filters.jl");

Problem: Stone Tools

Archeologists have asked for your help in analyzing data on stone tools. It is believed that primitive humans created tools by striking stones against one another. During this process, the stone loses flakes, which have been preserved. The archeologists have recovered these flakes from various locations and time periods and want to know whether this stone tool shaping process has improved over the centuries.

Data

The data is available from the UCI Machine Learning Repository. Each instance represents summary information of the stone flakes for a particular site. We will be using the attributes flaking angle (FLA) and the proportion of the dorsal surface worked (PROZD) for now.

In [3]:
dataset = load("../datasets/stoneflakes.jld");

I've done some pre-processing on the data set, namely z-scoring and removing two outliers. This reduces the scale of the attributes, which helps numerical stability during optimization.
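For reference, that kind of standardization is a one-liner; a minimal sketch, assuming a hypothetical unprocessed matrix `raw_data` with sites in rows and attributes in columns (the loaded file is already preprocessed, so you don't need to run this):

# Illustrative only: the dataset in the .jld file is already z-scored.
# `raw_data` is a hypothetical raw matrix (rows = sites, columns = attributes).
using Statistics
zscored_data = (raw_data .- mean(raw_data, dims=1)) ./ std(raw_data, dims=1)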

Now let's visualize the data with a scatterplot.

In [4]:
scatter(dataset["data"][:,1], 
        dataset["data"][:,2], 
        label="", 
        xlabel="Proportion of worked dorsal surface (PROZD)",
        ylabel="Flaking angle (FLA)")
Out[4]:

Model specification

We will be clustering this data with a Gaussian mixture model, to see if we can identify clear types of stone tools. The generative model for a Gaussian mixture consists of:

$$ p(X, z, \mu, \Lambda, \phi) =\ \underbrace{p(X \mid z, \mu, \Lambda)}_{\text{likelihood}}\ \times \ \underbrace{p(z \mid \phi)}_{\text{prior latent variables}} \ \times \ \underbrace{p(\mu \mid \Lambda)\ p(\Lambda)\ p(\phi)}_{\text{prior parameters}}$$

with the likelihood of observations $X$ being a Gaussian raised to a power set by the latent assignment variables $z$

$$ p(X \mid z, \mu, \Lambda) = \prod_{i=1}^{N} \prod_{k=1}^{K} \mathcal{N}(x_i \mid \mu_k, \Lambda_k^{-1})^{[z_i = k]}$$

the prior for each latent variable $z_i$ being a Categorical distribution

$$ p(z_i \mid \phi) = \text{Categorical}(z_i \mid \phi) $$

and priors for the parameters being

$$ \begin{align*} p(\mu_k \mid \Lambda_k) =&\ \mathcal{N}(\mu_k \mid m_0, l_0^{-1}\Lambda_k^{-1}) \qquad &\text{for all}\ k \\ p(\Lambda_k) =&\ \text{Wishart}(\Lambda_k \mid W_0, v_0) \qquad &\text{for all}\ k \\ p(\phi) =&\ \text{Dirichlet}(\phi \mid a_0) \, . \end{align*}$$

We will be implementing this model directly in ForneyLab. If you're unfamiliar with these distributions or with the Gaussian mixture model, have another look at Bert's lectures.
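To make the generative story concrete, here is a minimal ancestral-sampling sketch of this mixture model with hypothetical parameter values. It assumes the Distributions.jl package, which the ForneyLab code below does not need:

# Sketch: ancestral sampling from a 3-component Gaussian mixture (hypothetical parameters).
# Distributions.jl is an assumption here and is only used for this illustration.
import Distributions

ϕ_true = [0.5, 0.3, 0.2]                          # mixture weights
μ_true = [[1.0, -1.0], [0.0, 0.0], [-1.0, 1.0]]   # component means
Σ_true = [0.2 .* [1.0 0.0; 0.0 1.0] for k in 1:3] # component covariances (Λ_k⁻¹)

z_sim = [rand(Distributions.Categorical(ϕ_true)) for i in 1:100]                      # assignments
x_sim = [rand(Distributions.MvNormal(μ_true[z_sim[i]], Σ_true[z_sim[i]])) for i in 1:100]  # observations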


First, we will do a bit of bookkeeping.

In [5]:
# Data dimensionality
num_features = size(dataset["data"],2)

# Sample size
num_samples = size(dataset["data"],1)

# Number of mixture components
num_components = 3;

I am going to need an identity matrix a couple of times, so I'm setting up a convenience variable.

In [6]:
Id = Matrix{Float64}(I, num_features, num_features);

Mixture models can be sensitive to initialization, so we are going to specify the prior means explicitly.

In [7]:
# Prior parameters
m0 = [ 1.0 0.0 -1.0;
      -1.0 0.0  1.0];

Now to start the factor graph. This is another iid setting, which means the graph will be too large to visualize once it's complete.

In [8]:
# Initialize vector variables
z = Vector{Variable}(undef, num_samples)
X = Vector{Variable}(undef, num_samples)
Λ = Vector{Variable}(undef, num_components)
μ = Vector{Variable}(undef, num_components)

# Start a graph
graph1 = FactorGraph()

# Mixture weights are drawn from a Dirichlet distribution
@RV ϕ ~ Dirichlet(ones(num_components))

θ = []
for k = 1:num_components
    
    # Parameters of k-th component
    @RV Λ[k] ~ Wishart(Id, num_features)
    @RV μ[k] ~ GaussianMeanPrecision(m0[:,k], Λ[k])
    
    push!(θ, μ[k], Λ[k])
end

for i = 1:num_samples
    
    # Assignment variable
    @RV z[i] ~ Categorical(ϕ)
    
    # Gaussian mixture component
    @RV X[i] ~ GaussianMixture(z[i], θ...)
    
    # Add data 
    placeholder(X[i], :X, dims=(num_features,), index=i)
end

The next step is to compile an inference algorithm.

In [9]:
# Specify recognition factorization (mean-field)
q = PosteriorFactorization(ϕ, μ[1], Λ[1], μ[2], Λ[2], μ[3], Λ[3], z, 
                           ids=[:ϕ, :μ_1, :Λ_1, :μ_2, :Λ_2, :μ_3, :Λ_3, :z])

# Generate the algorithm
algorithm = messagePassingAlgorithm(free_energy=true)
source_code = algorithmSourceCode(algorithm, free_energy=true);
eval(Meta.parse(source_code));

After that, we feed in data, initialize recognition factors and run the inference procedure.

In [10]:
# Convert data to a format suited to ForneyLab
observations = [dataset["data"][i,:] for i in 1:num_samples]

# Add to data dictionary
data = Dict(:X => observations)

# Prepare recognition distributions
marginals = Dict()
marginals[:ϕ] = ProbabilityDistribution(Dirichlet, a=ones(num_components,))
for k = 1:num_components
    marginals[:μ_*k] = ProbabilityDistribution(Multivariate, GaussianMeanPrecision, m=m0[:,k], w=Id)
    marginals[:Λ_*k] = ProbabilityDistribution(Wishart, v=Id, nu=num_features)
end
for i = 1:num_samples
    marginals[:z_*i] = ProbabilityDistribution(Categorical, p=ones(num_components,)./num_components)
end

# Number of iterations
num_iterations = 20

# Track free energy across iterations
F = Float64[]

# Execute algorithm
@showprogress for i = 1:num_iterations
    
    # Update assignments
    stepz!(data, marginals)
    
    # Update parameters
    stepϕ!(data, marginals)
    stepμ_1!(data, marginals)
    stepΛ_1!(data, marginals)
    stepμ_2!(data, marginals)
    stepΛ_2!(data, marginals)
    stepμ_3!(data, marginals)
    stepΛ_3!(data, marginals)
        
    # Store variational free energy for visualization
    push!(F, freeEnergy(data, marginals))
end
Progress: 100%|█████████████████████████████████████████| Time: 0:00:23

Alright, we're done. Let's track the evolution of free energy.

In [11]:
# Plot free energy to check for convergence
plot(1:num_iterations, F, color="black", label="")
xlabel!("Number of iterations")
ylabel!("Free Energy")
Out[11]:

The free energy is decreasing nicely. If it has not flattened out by the final iteration, we could increase the number of iterations a bit.
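One simple way to judge whether more iterations are needed is to look at the change in free energy over the last iteration; a minimal sketch:

# Relative change in free energy over the final iteration; a small value suggests convergence.
ΔF_rel = abs(F[end] - F[end-1]) / abs(F[end-1])
println("Relative change in free energy: $(ΔF_rel)")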

Let's now visualize the clusters on top of the observations.

In [12]:
# Estimated means (unsafeMean retrieves parameters directly)
μ1_estimated = unsafeMean(marginals[:μ_1])
μ2_estimated = unsafeMean(marginals[:μ_2])
μ3_estimated = unsafeMean(marginals[:μ_3])

# Estimated precisions
Λ1_estimated = unsafeMean(marginals[:Λ_1])
Λ2_estimated = unsafeMean(marginals[:Λ_2])
Λ3_estimated = unsafeMean(marginals[:Λ_3])

# Invert to covariances
Σ1_estimated = inv(Λ1_estimated)
Σ2_estimated = inv(Λ2_estimated)
Σ3_estimated = inv(Λ3_estimated)

# Select dimensions to plot
dims_plot = [1, 2]
dim_limsx = [minimum(dataset["data"][:,dims_plot[1]])-1, maximum(dataset["data"][:,dims_plot[1]])+1]
dim_limsy = [minimum(dataset["data"][:,dims_plot[2]])-1, maximum(dataset["data"][:,dims_plot[2]])+1]

# Plot data and overlay estimated posterior probabilities
plot_clusters(dataset["data"][:, dims_plot], 
              μ=[μ1_estimated[dims_plot], μ2_estimated[dims_plot], μ3_estimated[dims_plot]], 
              Σ=[Σ1_estimated[dims_plot,dims_plot], Σ2_estimated[dims_plot,dims_plot], Σ3_estimated[dims_plot,dims_plot]], 
              x1=range(dim_limsx[1], step=0.01, stop=dim_limsx[2]), 
              x2=range(dim_limsy[1], step=0.01, stop=dim_limsy[2]), 
              size=(600,400))
Out[12]:

That doesn't look bad. The three Gaussians nicely cover all samples.
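If you want a quantitative summary as well, you can read off hard cluster assignments from the assignment marginals; a minimal sketch, assuming `unsafeMean` also returns the probability vector of a Categorical marginal:

# Hard cluster assignments (sketch): most probable component per observation.
# Assumption: unsafeMean on a Categorical marginal returns its probability vector.
assignments = [argmax(unsafeMean(marginals[:z_*i])) for i = 1:num_samples]

# Count how many observations land in each cluster (countmap comes from StatsBase).
countmap(assignments)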


$\ast$ Try for yourself

Play around with the number of components. Can you get an equally good coverage with just 2 components? What if you had 4?
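For example, for two components you could keep everything the same and only change the parts below (the prior means here are a hypothetical choice):

# Sketch of the changes needed for K = 2 components.
num_components = 2
m0 = [ 1.0 -1.0;
      -1.0  1.0]

# After rebuilding the graph as before, adapt the recognition factorization to match:
q = PosteriorFactorization(ϕ, μ[1], Λ[1], μ[2], Λ[2], z,
                           ids=[:ϕ, :μ_1, :Λ_1, :μ_2, :Λ_2, :z])
# The marginal initialization and the update loop then need the matching μ_*, Λ_* entries
# and step functions (stepμ_1!, stepΛ_1!, stepμ_2!, stepΛ_2!).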


Problem: Alpine Railways

The Swiss Federal Railways company operates a series of mountain railways bringing hikers (in the summer) and skiers (in the winter) up the Alps. They are setting up a new fallback security system where they intend to track trains through cameras and remote sensors. They want you to design a system to keep track of the trains' positions.

Data

A train going uphill updates its position according to: new position = old position + velocity × time-step length + noise. The noise represents the train randomly slipping and sliding back down. We observe the train through a remote sensor, producing noisy observations of its position.

You receive a data set with past recordings. Your job is to set up an online filtering system, which can be deployed later on to process the incoming signal in real-time.
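To make the described dynamics concrete, here is a small simulation sketch with a two-dimensional state (position and velocity) and hypothetical noise levels; the assignment itself uses the recorded data loaded below:

# Illustrative simulation only; not the provided recordings.
Δt_sim = 1.0
A_sim = [1.0 Δt_sim;       # new position = old position + velocity * Δt
         0.0 1.0]          # velocity carries over
x_sim = zeros(2, 50)       # states: [position; velocity]
y_sim = zeros(50)          # noisy position observations
x_sim[:,1] = [0.0, 1.0]    # start at position 0 with unit velocity
for k = 2:50
    x_sim[:,k] = A_sim * x_sim[:,k-1] .+ sqrt(0.01) .* randn(2)  # process noise (slipping)
    y_sim[k]   = x_sim[1,k] + sqrt(0.1) * randn()                # sensor noise
end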

In [13]:
signal = load("../datasets/alpinerails_filtering.jld");

# Unpack data
states = signal["X"]
observations = signal["Y"]
transition = signal["A"]
emission = signal["C"]
process_noise = signal["Q"]
measurement_noise = signal["R"]
T = signal["T"]
Δt = signal["Δt"];

# Size
M = size(states,1)
N = size(observations,1)

# Visualize
plot(1:T, states[1,:], color="red", label="states", grid=false, xlabel="time (t)", ylabel="train position")
scatter!(1:T, observations[1,:], color="black", label="observations")
Out[13]:

Model specification

We are going to use a Kalman filter to track the train's position. To do this, we define noisy observations $y_k$ of latent states $x_k$. Observations are generated through a Gaussian likelihood centered on an emission matrix $C$ times the current state, with measurement noise covariance $R$. State transitions follow a Gaussian distribution centered on a transition matrix $A$ times the previous state, with process noise covariance $Q$. In equation form, these are:

$$\begin{align} p(x_k \mid x_{k-1}) =&\ \mathcal{N}(x_k \mid A x_{k-1}, Q)\\ p(y_k \mid x_k) =&\ \mathcal{N}(y_k \mid C x_k, R) \, . \end{align}$$

We place a prior on the previous state, $x_{k-1} \sim \mathcal{N}(m_{k-1}, V_{k-1})$. In filtering problems, the posterior estimate of the current state becomes the prior for the next time-step.
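For reference, sum-product message passing on this model reproduces the classical Kalman filter recursions (with $Q$ and $R$ as covariances): a prediction step followed by a measurement update,

$$\begin{align} m_k^- =&\ A m_{k-1}, \qquad V_k^- = A V_{k-1} A^\top + Q \\ K_k =&\ V_k^- C^\top \left( C V_k^- C^\top + R \right)^{-1} \\ m_k =&\ m_k^- + K_k \left( y_k - C m_k^- \right), \qquad V_k = (I - K_k C)\, V_k^- \, . \end{align}$$

ForneyLab derives these updates automatically from the graph, so we never have to write them out by hand.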

In this case, we know the transition and emission matrices as well as the noise covariances. In general, you won't know these exactly, but the Bayesian framework allows you to estimate them simultaneously with the states. With ForneyLab, this often only requires defining additional priors: for the noise sources you would use Wishart priors (see the sketch below), while the emission and transition matrices are more difficult, unless they are structured in some convenient way (e.g. autoregression).
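As an illustration of that last point, here is a hypothetical sketch (not part of this assignment) of an unknown process noise: place a Wishart prior on the process noise precision and use a mean-precision node for the transition. Inference would then also need a variational factorization over this precision, as in the mixture example above.

# Hypothetical sketch: unknown process noise precision.
Id_M = Matrix{Float64}(I, M, M)
@RV W_x ~ Wishart(Id_M, M)                                  # Wishart prior on the precision
@RV x_k ~ GaussianMeanPrecision(transition * x_kmin1, W_x)  # replaces the transition node above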

In [14]:
# Initialize a graph
graph2 = FactorGraph()

# Define initial state prior
@RV x_kmin1 ~ GaussianMeanVariance(placeholder(:m_kmin1, dims=(M,)), 
                                   placeholder(:V_kmin1, dims=(M,M)))
    
# State transition
@RV x_k ~ GaussianMeanVariance(transition * x_kmin1, process_noise)
    
# Observation likelihood
@RV y_k ~ GaussianMeanVariance(dot([1., 0.], x_k), measurement_noise)
    
# Tell FL that y is observed
placeholder(y_k, :y_k);

# Visualize subgraph
ForneyLab.draw(graph2)
[Factor graph for one time-step: the placeholders m_kmin1 and V_kmin1 feed a Gaussian node for x_kmin1, which is multiplied by the (clamped) transition matrix and passed through a Gaussian transition node to x_k; a dot-product emission and a Gaussian likelihood node connect x_k to the observation placeholder y_k.]

We'll now construct the algorithm and infer results.

In [15]:
# Generate inference algorithm
algorithm = messagePassingAlgorithm(x_k)
source_code = algorithmSourceCode(algorithm)
eval(Meta.parse(source_code));

For filtering, we use the same graph in each time-step.

In [16]:
# Initialize recognition distribution marginals
marginals = Dict(:x_k => vague(GaussianMeanVariance, M))

# Initialize message array
messages = Array{Message}(undef, 5)

# Keep track of estimates
m_x = zeros(M,T+1)
V_x = repeat(Matrix{Float64}(I,M,M), outer=(1,1,T+1))

@showprogress for k = 1:T
    
    # Initialize data
    data = Dict(:y_k => observations[k],
                :m_kmin1 => m_x[:,k],
                :V_kmin1 => V_x[:,:,k])
    
    # Update states
    step!(data, marginals, messages)
    
    # Store estimates
    m_x[:,k+1] = mean(marginals[:x_k])
    V_x[:,:,k+1] = cov(marginals[:x_k])
    
end
Progress: 100%|█████████████████████████████████████████| Time: 0:00:02

Let's check whether everything went OK by visualizing the state estimates.

In [17]:
# Visualization
plot(1:T, states[1,:], color="red", label="states", grid=false)
plot!(1:T, m_x[1,2:end], ribbon=[sqrt.(V_x[1,1,2:end]), sqrt.(V_x[1,1,2:end])], fillalpha=0.2, color="blue", label="inferred")
scatter!(1:T, observations[1,:], color="black", label="observations")
Out[17]:

We're going to inspect some messages. Let's open up the algorithm and look up the marginal computation for the final $x_k$. It is the product of two messages: the state transition prediction and the measurement likelihood.

In [18]:
println(source_code)
begin

function step!(data::Dict, marginals::Dict=Dict(), messages::Vector{Message}=Array{Message}(undef, 5))

messages[1] = ruleSPGaussianMeanVarianceOutNPP(nothing, Message(Multivariate, PointMass, m=data[:m_kmin1]), Message(MatrixVariate, PointMass, m=data[:V_kmin1]))
messages[2] = ruleSPMultiplicationOutNGP(nothing, messages[1], Message(MatrixVariate, PointMass, m=[1.0 0.1; 0.0 1.0]))
messages[3] = ruleSPGaussianMeanVarianceOutNGP(nothing, messages[2], Message(MatrixVariate, PointMass, m=[0.01 0.0; 0.0 0.1]))
messages[4] = ruleSPGaussianMeanVarianceMPNP(Message(Univariate, PointMass, m=data[:y_k]), nothing, Message(Univariate, PointMass, m=0.1))
messages[5] = ruleSPDotProductIn1GNP(messages[4], nothing, Message(Multivariate, PointMass, m=[1.0, 0.0]))

marginals[:x_k] = messages[3].dist * messages[5].dist

return marginals

end

end # block

Alright, we need messages 3 and 5. Let's visualize them along with the state marginal.

In [19]:
# Collect state transition prediction
m_statepred = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, messages[3].dist)

# Collect observation likelihood
m_likelihood = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, messages[5].dist)

# Collect corrected prediction
state_marginal = convert(ProbabilityDistribution{Multivariate, GaussianMeanVariance}, marginals[:x_k])

# Extract the first (position) coordinate of each distribution
m_statepred_x = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=m_statepred.params[:m][1], v=m_statepred.params[:v][1,1])
m_likelihood_x = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=m_likelihood.params[:m][1], v=m_likelihood.params[:v][1,1])
state_marginal_x = ProbabilityDistribution(Univariate, GaussianMeanVariance, m=state_marginal.params[:m][1], v=state_marginal.params[:v][1,1])

# Plot of the prediction, noisy measurement, and corrected prediction for x-coordinate
plot_messages(m_statepred_x, m_likelihood_x, state_marginal_x)
Out[19]:

As you can see, the state estimate combines the state prediction, produced by the message from the state transition node, with the observation likelihood, produced by the message coming from the observation side of the graph.
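That combination is literally the (renormalized) product of the two Gaussian messages. For two Gaussians over the same variable,

$$ \mathcal{N}(x \mid m_1, V_1)\, \mathcal{N}(x \mid m_2, V_2) \propto \mathcal{N}(x \mid m, V)\,, \qquad V = \left( V_1^{-1} + V_2^{-1} \right)^{-1}, \quad m = V \left( V_1^{-1} m_1 + V_2^{-1} m_2 \right) \, ,$$

so whichever message has the smaller variance (higher precision) pulls the posterior mean towards itself.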


$\ast$ Try for yourself

Re-run the inference procedure and stop at an earlier time-step, for example $k=2$. How does the balance between the state prediction and the observation likelihood differ?
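One way to do this is to reset the estimates, run the filtering loop only up to an early time-step, and then re-execute the message-inspection and plotting cells above; a sketch:

# Sketch: rerun filtering up to k_stop, then inspect the final messages again.
k_stop = 2
m_x = zeros(M, T+1)
V_x = repeat(Matrix{Float64}(I, M, M), outer=(1,1,T+1))
for k = 1:k_stop
    data = Dict(:y_k => observations[k],
                :m_kmin1 => m_x[:,k],
                :V_kmin1 => V_x[:,:,k])
    step!(data, marginals, messages)
    m_x[:,k+1] = mean(marginals[:x_k])
    V_x[:,:,k+1] = cov(marginals[:x_k])
end
# Now re-run the cells that convert and plot messages[3], messages[5] and marginals[:x_k].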