using Plots
using HTTP, CSV
using DataFrames: DataFrame
using AugmentedGaussianProcesses
using MLDataUtils
# Download the CSV from OpenML and parse it straight into a DataFrame.
data = CSV.read(
    HTTP.get("https://www.openml.org/data/get_csv/1586217/phpwRjVjk").body, DataFrame
)
# Recode the class labels in place: every 2 becomes -1 (the other class stays 1).
replace!(data.Class, 2 => -1)
# Plain numeric matrix: columns 1-2 are the inputs, the last column is the label.
data = Matrix(data)
X = data[:, 1:2]
Y = Int.(data[:, end]);
# Split the observations (rows, hence `ObsDim.First()`) in half: the first half
# becomes the training set, the second half the test set.
(X_train, y_train), (X_test, y_test) = splitobs((X, Y), 0.5, ObsDim.First())
(([1.14 -0.114; -1.52 -1.15; … ; -0.0708 0.439; 0.177 -1.37], [1, -1, 1, -1, -1, -1, 1, -1, -1, -1 … 1, -1, 1, -1, -1, -1, 1, -1, -1, 1]), ([-0.135 0.136; -0.288 0.385; … ; 0.769 0.772; -0.255 -0.142], [-1, -1, -1, 1, 1, -1, -1, 1, -1, 1 … -1, -1, -1, -1, -1, -1, -1, -1, 1, -1]))
"""
    plot_data(X, Y; size=(300, 500))

Return a scatter plot of the inputs `X` (one plotting axis per column of `X`),
with points grouped and coloured by the labels `Y`. Markers are semi-transparent,
have no outline, and carry no legend label.
"""
function plot_data(X, Y; size=(300, 500))
    columns = eachcol(X)
    return Plots.scatter(
        columns...;
        group=Y,
        xlabel="x",
        ylabel="y",
        alpha=0.2,
        markerstrokewidth=0.0,
        lab="",
        size,
    )
end
plot_data(X, Y, size=(500, 800)) # every observation (train and test), coloured by class
Binary classification with Gaussian processes is usually formulated as $y \sim \mathrm{Bernoulli}(h(f))$,
where $f$ is the latent Gaussian process and $h$ is the inverse link function.
Multiple choices exist for $h$, but we will focus mostly on $h(x) = \sigma(x) = (1 + \exp(-x))^{-1}$,
i.e. the logistic function.
Ms = [4, 8, 16, 32, 64] # Number of inducing points
# One slot per sparse model, plus a final slot reserved for the full (non-sparse) model.
models = Vector{AbstractGPModel}(undef, length(Ms) + 1)
kernel = with_lengthscale(SqExponentialKernel(), 1.0) # We create a standard kernel with lengthscale 1
for (i, num_inducing) in enumerate(Ms)
    @info "Training with $(num_inducing) points"
    m = SVGP(
        kernel,
        LogisticLikelihood(),
        AnalyticVI(),
        # Z is selected via the kmeans algorithm on the *training* inputs only:
        # clustering the full `X` would let the test inputs leak into the model.
        inducingpoints(KmeansAlg(num_inducing), X_train);
        optimiser=false, # We keep the kernel parameters fixed
        Zoptimiser=false, # We keep the inducing points locations fixed
    )
    @time train!(m, X_train, y_train, 5) # We train the model on the training data for 5 iterations
    models[i] = m # And store the model
end
[ Info: Training with 4 points 7.881165 seconds (20.69 M allocations: 1.112 GiB, 4.90% gc time, 99.74% compilation time) [ Info: Training with 8 points 0.002279 seconds (869 allocations: 4.319 MiB) [ Info: Training with 16 points 0.006113 seconds (893 allocations: 7.668 MiB) [ Info: Training with 32 points 0.012649 seconds (941 allocations: 14.532 MiB) [ Info: Training with 64 points 0.032674 seconds (1.10 k allocations: 28.946 MiB)
@info "Running full model"
# Non-sparse variational GP on the whole training set, same fixed kernel as above.
mfull = VGP(
    X_train, y_train, kernel, LogisticLikelihood(), AnalyticVI(); optimiser=false
)
@time train!(mfull, 5) # 5 iterations, as for the sparse models
models[lastindex(models)] = mfull
[ Info: Running full model 18.819545 seconds (3.73 M allocations: 2.459 GiB, 1.88% gc time, 8.94% compilation time)
Variational Gaussian Process with a BernoulliLikelihood{GPLikelihoods.LogisticLink}(GPLikelihoods.LogisticLink(LogExpFunctions.logistic)) infered by Analytic Variational Inference
"""
    compute_grid(model, n_grid=50; mins=(-3.25, -2.85), maxs=(3.65, 3.4))

Evaluate `model`'s predicted class probabilities (the first output of `proba_y`)
on a regular `n_grid × n_grid` grid covering
`[mins[1], maxs[1]] × [mins[2], maxs[2]]`.

# Arguments
- `model`: a trained AugmentedGaussianProcesses model accepted by `proba_y`.
- `n_grid`: number of grid points along each axis.

# Keywords
- `mins`, `maxs`: lower/upper bounds of the two grid axes. The defaults are the
  bounds originally hard-coded for the dataset in this notebook; override them to
  reuse the function on inputs with a different range.

# Returns
`(y_grid, x_lin, y_lin)`, where `x_lin` and `y_lin` are the grid axes and `y_grid`
holds the predictions at every grid point with the first coordinate varying fastest,
so `reshape(y_grid, n_grid, n_grid)` is indexed as `[ix, iy]`.
"""
function compute_grid(model, n_grid=50; mins=(-3.25, -2.85), maxs=(3.65, 3.4))
    x_lin = range(mins[1], maxs[1]; length=n_grid)
    y_lin = range(mins[2], maxs[2]; length=n_grid)
    # `Iterators.product` iterates the first axis fastest; `collect` turns each
    # (x, y) tuple into a 2-element vector so `proba_y` receives a vector of points.
    x_grid = Iterators.product(x_lin, y_lin)
    y_grid, _ = proba_y(model, vec(collect.(x_grid)))
    return y_grid, x_lin, y_lin
end
"""
    plot_model(model, X, Y, title=nothing; size=(300, 500))

Scatter `(X, Y)` via `plot_data` and overlay the 0.5-level contour of `model`'s
predictions on a 50 × 50 grid (from `compute_grid`). For an `SVGP`, the inducing
point locations are also drawn as small black markers.

When `title` is `nothing`, it defaults to `"M = <number of inducing points>"` for an
`SVGP` and to `"full"` for any other model. Returns the plot object.
"""
function plot_model(model, X, Y, title=nothing; size=(300, 500))
    n_grid = 50
    probs, xs, ys = compute_grid(model, n_grid)
    is_sparse = model isa SVGP
    if isnothing(title)
        title = is_sparse ? "M = $(AGP.dim(model[1]))" : "full"
    end
    plt = plot_data(X, Y; size=size)
    # `compute_grid` varies x fastest, while `contour!` expects rows to follow y,
    # hence the transpose.
    surface = reshape(probs, n_grid, n_grid)'
    Plots.contour!(
        plt,
        xs,
        ys,
        surface;
        levels=[0.5],
        fill=false,
        cbar=false,
        linewidth=2.0,
        color=:black,
        title=title,
    )
    if is_sparse
        # Stack the inducing points as columns so each row is one coordinate.
        Z = hcat(AGP.Zview(model[1])...)
        Plots.scatter!(plt, eachrow(Z)...; msize=2.0, color="black", lab="")
    end
    return plt
end;
# One panel per trained model: the sparse models in order of increasing M, then the full VGP.
Plots.plot(
    map(m -> plot_model(m, X, Y), models)...;
    size=(1000, 200),
    layout=(1, length(models)),
)
@info "Running model with Bayesian SVM Likelihood"
# Same full VGP set-up as the logistic model, but with the Bayesian SVM likelihood.
mbsvm = VGP(
    X_train,
    y_train,
    kernel,
    BayesianSVM(),
    AnalyticVI();
    optimiser=false, # kernel hyperparameters stay fixed
)
@time train!(mbsvm, 5)
[ Info: Running model with Bayesian SVM Likelihood 18.078650 seconds (1.61 M allocations: 2.344 GiB, 3.07% gc time, 4.05% compilation time)
(Variational Gaussian Process with a BernoulliLikelihood{AugmentedGaussianProcesses.SVMLink}(AugmentedGaussianProcesses.SVMLink()) infered by Analytic Variational Inference , (local_vars = (c = [0.19066459123745133, 0.23164540665694638, 0.003154447115920155, 0.13783608432003847, 0.23493768647470348, 4.754524325432238, 1.6762565727015175, 1.4822238505116838, 0.4547369073065402, 0.07199295745857563 … 0.36745423151625917, 0.46906021066829073, 0.10311036630835056, 0.8201390461860333, 3.2678129506884424, 0.679356962296097, 5.001778909682685, 0.4947182427853843, 1.2812846432133913, 2.956175231330885], θ = [2.2901555264125135, 2.0777254216310244, 17.804852267914043, 2.69350965394914, 2.0631160426393187, 0.458613107625561, 0.7723777475440399, 0.821378066129373, 1.4829274286914775, 3.726962239776205 … 1.649674926216585, 1.4601104286331292, 3.114216742000558, 1.104221644102182, 0.5531862862465108, 1.21325191286659, 0.44713406146238294, 1.4217427992508633, 0.883440265183066, 0.5816140819612164]), opt_state = (NamedTuple(),), hyperopt_state = (NamedTuple(),), kernel_matrices = ((K = LinearAlgebra.Cholesky{Float64, Matrix{Float64}}([1.0000499987500624 0.017000746952647305 … 0.4123128729363635 0.2857887156916332; 0.017001596968745064 0.9999054828347788 … 0.09200744609514472 0.22644779606578086; … ; 0.4123334880646449 0.0990083766302707 … 0.010077146944510723 -1.6990382050684463e-6; 0.2858030047701997 0.23128501449942165 … 0.18882343252913514 0.010065643970442203], 'U', 0),),)))
# Side-by-side decision boundaries: full logistic VGP vs. Bayesian SVM.
Plots.plot(
    map(
        (model, label) -> plot_model(model, X, Y, label; size=(500, 250)),
        [models[end], mbsvm],
        ["Logistic", "BSVM"],
    )...;
    layout=(1, 2),
)
This notebook was generated using Literate.jl.