using Plots
using HTTP, CSV
using DataFrames: DataFrame
using AugmentedGaussianProcesses
# Download the CSV from OpenML and parse the response body straight into a DataFrame.
data = CSV.read(
    HTTP.get("https://www.openml.org/data/get_csv/1586217/phpwRjVjk").body, DataFrame
)
# Recode class label 2 as -1 in place; every other label is left untouched.
data.Class[findall(==(2), data.Class)] .= -1
# Switch to a plain matrix: the first two columns are the inputs and the last
# column is the (recoded) label.
data = Matrix(data)
X = data[:, 1:2]
Y = data[:, end];
"""
    plot_data(X, Y; size=(300, 500))

Scatter-plot the inputs `X` and group (colour) the points by the labels `Y`.

Each column of `X` is passed to `Plots.scatter` as a separate coordinate argument,
so a two-column `X` gives x against y. Points are drawn semi-transparent
(`alpha=0.2`), without marker outlines and without legend labels.
Returns the new plot.
"""
function plot_data(X, Y; size=(300, 500))
    columns = eachcol(X)
    style = (group=Y, alpha=0.2, markerstrokewidth=0.0, lab="", size=size)
    return Plots.scatter(columns...; style...)
end
# Show the raw training data, grouped by label, on a square canvas.
let canvas = (500, 500)
    plot_data(X, Y; size=canvas)
end
# Numbers of inducing points to try for the sparse models.
Ms = [4, 8, 16, 32, 64]
# One slot per sparse model, plus a final slot filled later by the full model.
models = Vector{AbstractGP}(undef, length(Ms) + 1)
# Squared-exponential kernel composed with a fixed input scaling of 1.0.
kernel = SqExponentialKernel() ∘ ScaleTransform(1.0)
for idx in eachindex(Ms)
    num_inducing = Ms[idx]
    @info "Training with $(num_inducing) points"
    # Sparse variational GP: logistic likelihood, analytic VI, `num_inducing`
    # inducing points. Both optimisers (`optimiser`, `Zoptimiser`) are disabled.
    sparse_model = SVGP(
        X, Y, kernel, LogisticLikelihood(), AnalyticVI(), num_inducing;
        optimiser=false, Zoptimiser=false,
    )
    @time train!(sparse_model, 20)
    models[idx] = sparse_model
end
[ Info: Training with 4 points 9.287114 seconds (18.72 M allocations: 1.119 GiB, 4.98% gc time, 99.19% compilation time) [ Info: Training with 8 points 0.024066 seconds (23.32 k allocations: 31.331 MiB) [ Info: Training with 16 points 0.081005 seconds (23.32 k allocations: 54.341 MiB, 53.06% gc time) [ Info: Training with 32 points 0.055143 seconds (23.32 k allocations: 100.906 MiB) [ Info: Training with 64 points 0.124772 seconds (23.51 k allocations: 196.233 MiB, 18.99% gc time)
# Full (non-sparse) variational GP with the same data, kernel and likelihood.
# It is trained for 5 iterations (the sparse models above used 20) and stored
# in the last slot of `models`.
@info "Running full model"
mfull = VGP(
    X, Y, kernel, LogisticLikelihood(), AnalyticVI();
    optimiser=false,
)
@time train!(mfull, 5)
models[end] = mfull
[ Info: Running full model 31.872851 seconds (3.50 M allocations: 8.172 GiB, 3.05% gc time, 5.52% compilation time)
Variational Gaussian Process with a Bernoulli Likelihood with Logistic Link infered by Analytic Variational Inference
"""
    compute_grid(model, n_grid=50; mins=(-3.25, -2.85), maxs=(3.65, 3.4))

Evaluate `proba_y(model, ⋅)` on a regular `n_grid × n_grid` grid covering the
rectangle `[mins[1], maxs[1]] × [mins[2], maxs[2]]`.

The default bounds are the window used throughout this notebook. Pass `mins`/`maxs`
to plot other data.

Returns `(y_grid, x_lin, y_lin)`:
- `x_lin` and `y_lin` are the grid axes.
- `y_grid` is the first value returned by `proba_y`.

The grid points are ordered with the x coordinate varying fastest, so
`reshape(y_grid, n_grid, n_grid)` is indexed `[ix, iy]`.
"""
function compute_grid(model, n_grid=50; mins=(-3.25, -2.85), maxs=(3.65, 3.4))
    x_lin = range(mins[1], maxs[1]; length=n_grid)
    y_lin = range(mins[2], maxs[2]; length=n_grid)
    # `Iterators.product` varies its first iterator fastest, so the flattened
    # points sweep x for each fixed y.
    x_grid = Iterators.product(x_lin, y_lin)
    # Each grid point is turned into a 2-element vector before prediction.
    y_grid, _ = proba_y(model, vec(collect.(x_grid)))
    return y_grid, x_lin, y_lin
end
"""
    plot_model(model, X, Y, title=nothing; size=(300, 500), n_grid=50)

Plot the training data `X`, `Y` (via `plot_data`) with the 0.5 level set of
`model`'s predicted probability overlaid. The probability is evaluated by
`compute_grid` on an `n_grid × n_grid` grid.

If `model` is an `SVGP`, its inducing points (`AGP.Zview(model[1])`) are also
drawn as small black dots.

When `title` is `nothing`, it defaults to:
- `"M = <AGP.dim(model[1])>"` for an `SVGP`;
- `"full"` for any other model.

Returns the plot.
"""
function plot_model(model, X, Y, title=nothing; size=(300, 500), n_grid=50)
    y_pred, x_lin, y_lin = compute_grid(model, n_grid)
    p = plot_data(X, Y; size=size)
    # `y_pred` is ordered x-fastest, so after reshaping, rows index x and columns
    # index y. Plots expects z[iy, ix] for contour(x, y, z), hence the transpose.
    Plots.contour!(
        p,
        x_lin,
        y_lin,
        reshape(y_pred, n_grid, n_grid)';
        cbar=false,
        levels=[0.5],
        fill=false,
        color=:black,
        linewidth=2.0,
        title=something(title, _default_title(model)),
    )
    _plot_inducing_points!(p, model)
    return p
end

# Default panel title: the `AGP.dim` of the first latent GP for a sparse model,
# "full" for anything else.
_default_title(model::SVGP) = "M = $(AGP.dim(model[1]))"
_default_title(model) = "full"

# Overlay the inducing points of a sparse model on `p`; no-op for other models.
function _plot_inducing_points!(p, model::SVGP)
    Plots.scatter!(
        p, eachrow(hcat(AGP.Zview(model[1])...))...; msize=2.0, color="black", lab=""
    )
    return p
end
_plot_inducing_points!(p, model) = p
# One panel per trained model (the sparse models first, the full model last),
# laid out in a single row.
let panels = [plot_model(m, X, Y) for m in models]
    Plots.plot(panels...; layout=(1, length(models)), size=(1000, 200))
end
# Full VGP set up like `mfull`, but with the Bayesian SVM likelihood
# instead of the logistic one; trained for 5 iterations.
mbsvm = VGP(
    X, Y, kernel,
    BayesianSVM(),
    AnalyticVI();
    optimiser=false,
)
@time train!(mbsvm, 5)
31.128042 seconds (1.91 M allocations: 8.074 GiB, 2.60% gc time, 2.33% compilation time)
# Side-by-side decision boundaries of the two full models:
# logistic likelihood versus Bayesian SVM.
let compared = [models[end] => "Logistic", mbsvm => "BSVM"]
    panels = [plot_model(m, X, Y, t; size=(500, 500)) for (m, t) in compared]
    Plots.plot(panels...; layout=(1, 2))
end
This notebook was generated using Literate.jl.