You are seeing the notebook output generated by Literate.jl from the Julia source file. The rendered HTML can be viewed in the docs.
# Background
This example trains a GP whose inputs are passed through a neural network.
This kind of model has been considered previously [^Calandra] [^Wilson], although it has been shown that some care is needed to avoid substantial overfitting [^Ober].
In this example we make use of the FunctionTransform from KernelFunctions.jl to put a simple multi-layer perceptron, built using Flux.jl, inside a standard kernel.
[^Calandra]: Calandra, R., Peters, J., Rasmussen, C. E., & Deisenroth, M. P. (2016, July). Manifold Gaussian processes for regression. In 2016 International Joint Conference on Neural Networks (IJCNN) (pp. 3338-3345). IEEE.
[^Wilson]: Wilson, A. G., Hu, Z., Salakhutdinov, R. R., & Xing, E. P. (2016). Stochastic variational deep kernel learning. Advances in Neural Information Processing Systems, 29.
[^Ober]: Ober, S. W., Rasmussen, C. E., & van der Wilk, M. (2021, December). The promises and pitfalls of deep kernel learning. In Uncertainty in Artificial Intelligence (pp. 1206-1216). PMLR.
We use a few packages to build the model, optimize the hyperparameters, and plot the results:
using AbstractGPs
using Distributions
using Flux
using KernelFunctions
using LinearAlgebra
using Plots
default(; legendfontsize=15.0, linewidth=3.0);
We create a simple 1D problem whose target function varies on very different length scales across the input domain:
xmin, xmax = (-3, 3) # Limits
N = 150
noise_std = 0.01
x_train_vec = rand(Uniform(xmin, xmax), N) # Training dataset
x_train = collect(eachrow(x_train_vec)) # vector-of-vectors for Flux compatibility
target_f(x) = sinc(abs(x)^abs(x)) # sinc of a rapidly growing argument: smooth near zero, increasingly oscillatory as |x| grows
y_train = target_f.(x_train_vec) + randn(N) * noise_std
x_test_vec = range(xmin, xmax; length=200) # Testing dataset
x_test = collect(eachrow(x_test_vec)) # vector-of-vectors for Flux compatibility
plot(xmin:0.01:xmax, target_f; label="ground truth")
scatter!(x_train_vec, y_train; label="training data")
We create a small neural network with three dense layers, mapping the 1-dimensional input through 20 and 30 hidden units to a 5-dimensional feature vector. The data is passed through this network before being used in the kernel.
neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
Chain(
  Dense(1 => 20),                       # 40 parameters
  Dense(20 => 30),                      # 630 parameters
  Dense(30 => 5),                       # 155 parameters
)                   # Total: 6 arrays, 825 parameters, 3.527 KiB.
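As a small added check (not part of the original script): each training input is a 1-element vector, which the Dense layers accept directly and map to the 5-dimensional feature vector on which the kernel will operate.
size(neuralnet(x_train[1])) # added illustration; expected output: (5,)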
We compose a squared exponential kernel with a FunctionTransform that applies the neural network to the inputs:
k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)
Squared Exponential Kernel (metric = Distances.Euclidean(0.0)) - Function Transform: Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
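As a quick sanity check, added here for illustration and using only the objects defined above: composing with a FunctionTransform means the base kernel is evaluated on the network outputs rather than on the raw inputs, i.e. k(x, x′) equals the squared exponential kernel applied to neuralnet(x) and neuralnet(x′).
x1, x2 = x_train[1], x_train[2] # two arbitrary training inputs
k(x1, x2) ≈ SqExponentialKernel()(neuralnet(x1), neuralnet(x2)) # expected: true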
We now define our model:
gpprior = GP(k) # GP Prior
fx = AbstractGPs.FiniteGP(gpprior, x_train, noise_std^2) # Prior at the observations
fp = posterior(fx, y_train) # Posterior of f given the observations
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(1 => 20)  # 40 parameters
│   summary(x) = "1-element view(::Matrix{Float64}, 1, :) with eltype Float64"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
AbstractGPs.PosteriorGP{AbstractGPs.GP{AbstractGPs.ZeroMean{Float64}, KernelFunctions.TransformedKernel{KernelFunctions.SqExponentialKernel{Distances.Euclidean}, KernelFunctions.FunctionTransform{Flux.Chain{Tuple{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}}, @NamedTuple{α::Vector{Float64}, C::LinearAlgebra.Cholesky{Float64, Matrix{Float64}}, x::Vector{SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}}, δ::Vector{Float64}}}(AbstractGPs.GP{AbstractGPs.ZeroMean{Float64}, KernelFunctions.TransformedKernel{KernelFunctions.SqExponentialKernel{Distances.Euclidean}, KernelFunctions.FunctionTransform{Flux.Chain{Tuple{Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}, Flux.Dense{typeof(identity), Matrix{Float32}, Vector{Float32}}}}}}}(AbstractGPs.ZeroMean{Float64}(), Squared Exponential Kernel (metric = Distances.Euclidean(0.0)) - Function Transform: Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))), (α = [659.9484048091723, 278.4393566741794, -115.31366173528781, 880.9224858197226, -40.17451444291513, -137.8337462957908, -1155.9499234732884, -2699.669906510716, 719.0396297640852, 334.518888601976 … -305.8936870550849, -335.3136784072349, -1125.2354434334536, -1270.716764460443, 458.99697861771455, -1465.4123539251689, -1210.6221054795694, -1202.8362910417677, 271.7134024559203, 92.06311598288465], C = LinearAlgebra.Cholesky{Float64, Matrix{Float64}}([1.0000499987500624 0.07067853731082743 … 0.3950620952951509 0.34237994540124933; 0.07068207114934921 0.9975492691409292 … -0.026775996264110105 -0.02342430931492509; … ; 0.39508184790611267 0.0012120355386286974 … 0.010348142689642778 0.0006397650150331975; 0.3423970639705658 0.0008320111082866788 … 0.9948867559432983 0.010404437791179928], 'U', 0), x = SubArray{Float64, 1, Matrix{Float64}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}}, true}[[0.7288910497636443], [-2.290836990467076], [-2.461938089192837], [-0.7942055497327991], [-2.960481087124503], [2.5611036037463446], [-1.2314626045474824], [-0.04794157576475655], [0.5817075028377183], [0.44261603727225074] … [-2.3244535591100712], [-2.3338570212663634], [1.1717947174585879], [1.345618102777948], [0.7730331523012159], [-1.325358084901843], [-1.2750376992066932], [1.3618464434165896], [2.5166526094794577], [2.6494796146263013]], δ = [0.24409107724069948, 0.04920656959790945, -0.012501364220813837, 0.17837313897592227, 0.0006831740254951753, -0.011265552723176383, -0.18888511299904756, 0.14358794516460555, 0.33367951629303255, 0.3643120151933217 … -0.012400125204950829, -0.016459714485879746, -0.1605529227762436, -0.21584454243045698, 0.19767130277894568, -0.2307119087500064, -0.20028244658268024, -0.2106126839173522, 0.03605849798296579, -0.0003022202209281375]))
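The Float32/Float64 warning above appears because the Dense layers hold Float32 parameters while the inputs are Float64; it is harmless here. One possible way to silence it, not done in this example, would be to promote the network to Float64 before wrapping it in the kernel:
neuralnet64 = Flux.f64(neuralnet) # sketch only: convert the layer parameters to Float64
k64 = SqExponentialKernel() ∘ FunctionTransform(neuralnet64) # so the layer eltype matches the Float64 inputs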
This computes the negative log evidence of y (i.e. the negative log marginal likelihood, viewed as a function of the neural-network parameters), which we use as the training objective:
loss(y) = -logpdf(fx, y)
@info "Initial loss = $(loss(y_train))"
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(1 => 20)  # 40 parameters
│   summary(x) = "1-element view(::Matrix{Float64}, 1, :) with eltype Float64"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
[ Info: Initial loss = 5315.579171905199
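For reference, the same quantity can be computed directly from the standard Gaussian-process marginal likelihood formula. The following sketch (added for illustration, assuming the zero-mean prior used above) should agree with loss(y_train) up to numerical round-off:
K = kernelmatrix(k, x_train) + noise_std^2 * I # covariance of the noisy observations
manual_nlml = (dot(y_train, K \ y_train) + logdet(K) + N * log(2π)) / 2 # 0.5 (yᵀ K⁻¹ y + logdet K + N log 2π)
isapprox(manual_nlml, loss(y_train); rtol=1e-6) # expected: true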
Flux automatically extracts all the trainable parameters of the kernel (here, the weights and biases of the neural network):
ps = Flux.params(k)
Params([Float32[-0.45399225; 0.17021179; … ; -0.13894099; -0.2883374;;], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.24679856 0.20540601 … -0.047029946 -0.2996298; -0.30882016 0.054984786 … -0.2422496 -0.29434207; … ; -0.20213942 -0.0768872 … -0.28867817 -0.30518577; -0.10598487 -0.26754272 … -0.10966233 0.1014743], Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 … 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Float32[-0.077661745 0.36521444 … -0.03137103 -0.05853775; 0.14965776 0.37962005 … 0.01209235 -0.40246785; … ; 0.32924226 -0.015455263 … -0.281077 -0.28713205; -0.3382412 0.24882494 … -0.25195035 -0.14829679], Float32[0.0, 0.0, 0.0, 0.0, 0.0]])
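As a small added check, these are exactly the six weight and bias arrays of the three Dense layers shown earlier, 825 parameters in total:
length(ps) # expected: 6 (three weight matrices and three bias vectors)
sum(length, ps) # expected: 825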
We show the initial prediction of the untrained model:
p_init = plot(; title="Loss = $(round(loss(y_train); sigdigits=6))")
plot!(vcat(x_test...), target_f; label="true f")
scatter!(vcat(x_train...), y_train; label="data")
pred_init = marginals(fp(x_test))
plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Prediction")
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(1 => 20)  # 40 parameters
│   summary(x) = "1-element view(::Matrix{Float64}, 1, :) with eltype Float64"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
nmax = 200
opt = Flux.Adam(0.1)
anim = Animation()
for i in 1:nmax
    grads = gradient(ps) do
        loss(y_train)
    end
    Flux.Optimise.update!(opt, ps, grads)
    if i % 10 == 0
        L = loss(y_train)
        @info "iteration $i/$nmax: loss = $L"
        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
        plot!(vcat(x_test...), target_f; label="true f")
        scatter!(vcat(x_train...), y_train; label="data")
        pred = marginals(posterior(fx, y_train)(x_test))
        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
        frame(anim)
        display(p)
    end
end
gif(anim, "train-dkl.gif"; fps=3)
nothing #hide
┌ Warning: Layer with Float32 parameters got Float64 input.
│   The input will be converted, but any earlier layers may be very slow.
│   layer = Dense(1 => 20)  # 40 parameters
│   summary(x) = "1-element view(::Matrix{Float64}, 1, :) with eltype Float64"
└ @ Flux ~/.julia/packages/Flux/vwk6M/src/layers/stateless.jl:59
[ Info: iteration 10/200: loss = -97.6936903453906
[ Info: iteration 20/200: loss = -100.12899181912914
[ Info: iteration 30/200: loss = 88.40026099808755
[ Info: iteration 40/200: loss = 95.28903852113136
[ Info: iteration 50/200: loss = 92.90680370112422
[ Info: iteration 60/200: loss = 84.1158845850552
[ Info: iteration 70/200: loss = 54.08771116020551
[ Info: iteration 80/200: loss = -202.86644458704274
[ Info: iteration 90/200: loss = -25.885423893148072
[ Info: iteration 100/200: loss = 111.44187014494054
[ Info: iteration 110/200: loss = 117.37463811366253
[ Info: iteration 120/200: loss = 118.4626470313062
[ Info: iteration 130/200: loss = 118.28528068236982
[ Info: iteration 140/200: loss = 117.61221364324251
[ Info: iteration 150/200: loss = 116.66316785662872
[ Info: iteration 160/200: loss = 115.46816617384722
[ Info: iteration 170/200: loss = 113.97831359553602
[ Info: iteration 180/200: loss = 112.06472125916275
[ Info: iteration 190/200: loss = 109.48594170580829
[ Info: iteration 200/200: loss = 105.75383317373125
[ Info: Saved animation to /home/runner/work/AbstractGPs.jl/AbstractGPs.jl/docs/src/examples/2-deep-kernel-learning/train-dkl.gif
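As an aside, here is a sketch (not used in this example) of how one optimisation step could instead be written with Flux's explicit-gradient API, rebuilding the kernel and the finite GP inside the loss so that the gradient is taken with respect to the network itself:
function nn_loss(nn) # sketch of an explicit-parameter objective
    k_nn = SqExponentialKernel() ∘ FunctionTransform(nn)
    fx_nn = AbstractGPs.FiniteGP(GP(k_nn), x_train, noise_std^2)
    return -logpdf(fx_nn, y_train)
end
opt_state = Flux.setup(Flux.Adam(0.1), neuralnet) # optimiser state for the network
grads = Flux.gradient(nn_loss, neuralnet) # gradient with respect to the network
Flux.update!(opt_state, neuralnet, grads[1]) # one optimisation step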
Status `~/work/AbstractGPs.jl/AbstractGPs.jl/examples/2-deep-kernel-learning/Project.toml`
  [99985d1d] AbstractGPs v0.5.23 `/home/runner/work/AbstractGPs.jl/AbstractGPs.jl#28319c5`
  [31c24e10] Distributions v0.25.118
⌅ [587475ba] Flux v0.14.25
  [ec8451be] KernelFunctions v0.10.65
  [98b081ad] Literate v2.20.1
  [cc2ba9b6] MLDataUtils v0.5.4
  [91a5bcdd] Plots v1.40.10
⌅ [e88e6eb3] Zygote v0.6.75
  [37e2e46d] LinearAlgebra v1.11.0
Info Packages marked with ⌅ have new versions available but compatibility constraints restrict them from upgrading. To see why use `status --outdated`.

To reproduce this notebook's package environment, you can download the full Manifest.toml.
Julia Version 1.11.4
Commit 8561cc3d68d (2025-03-10 11:36 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)
Environment:
  JULIA_DEBUG = Documenter
  JULIA_LOAD_PATH = :/home/runner/.julia/packages/JuliaGPsDocs/7M86H/src
This notebook was generated using Literate.jl.