# Activate the course project environment; clear notebook output (best effort).
using Pkg; Pkg.activate("../."); Pkg.instantiate();
using IJulia
try
    IJulia.clear_output()
catch _
end

using Plots, Distributions

# --- Generate a synthetic two-class data set (apples vs. peaches) -----------
N = 250                          # number of samples
p_apple = 0.7                    # prior probability of the "apple" class
Σ = [0.2 0.1; 0.1 0.3]           # feature covariance, shared by both classes
p_given_apple = MvNormal([1.0, 1.0], Σ)   # p(X|y=apple)
p_given_peach = MvNormal([1.7, 2.5], Σ)   # p(X|y=peach)

X = Matrix{Float64}(undef, 2, N)          # feature matrix, one column per sample
y = Vector{Bool}(undef, N)                # labels; true corresponds to apple
for i = 1:N
    y[i] = (rand() < p_apple)                                   # apple or peach?
    X[:, i] = y[i] ? rand(p_given_apple) : rand(p_given_peach)  # sample features
end

# Sort features on class (after transpose: one row per sample).
X_apples  = X[:, findall(y)]'
X_peaches = X[:, findall(.!y)]'
x_test = [2.3; 1.5]              # features of a 'new', unlabelled data point

scatter(X_apples[:, 1], X_apples[:, 2], label="apples", marker=:x, markerstrokewidth=3)
scatter!(X_peaches[:, 1], X_peaches[:, 2], label="peaches", marker=:+, markerstrokewidth=3)
scatter!([x_test[1]], [x_test[2]], label="unknown")   # the unlabelled point

# Make sure you run the data-generating code cell first
using Distributions, Plots

# --- Generative classification ----------------------------------------------
# Multinomial (in this case binomial) ML estimate of the class prior.
p_apple_est = count(y) / length(y)
π_hat = [p_apple_est; 1 - p_apple_est]

# Class-conditional multivariate Gaussian densities, fit by maximum likelihood.
d1 = fit_mle(FullNormal, X_apples')        # d1 = N(μ₁, Σ₁)
d2 = fit_mle(FullNormal, X_peaches')       # d2 = N(μ₂, Σ₂)
Σ = π_hat[1]*cov(d1) + π_hat[2]*cov(d2)    # pool Σ₁ and Σ₂ into a shared Σ
conditionals = [MvNormal(mean(d1), Σ), MvNormal(mean(d2), Σ)]   # p(x|Ck)

# Posterior class probability p(Ck|X) of a feature vector X (Bayes' rule).
function predict_class(k, X)
    evidence = π_hat[1]*pdf(conditionals[1], X) + π_hat[2]*pdf(conditionals[2], X)
    return π_hat[k] * pdf(conditionals[k], X) / evidence
end
println("p(apple|x=x∙) = $(predict_class(1,x_test))")

# Discrimination boundary of the posterior: p(apple|x;D) = p(peach|x;D) = 0.5.
β(k) = inv(Σ) * mean(conditionals[k])
γ(k) = -0.5 * mean(conditionals[k])' * inv(Σ) * mean(conditionals[k]) + log(π_hat[k])

# Solve the (linear) discriminant equation for x2, given x1.
function discriminant_x2(x1)
    Δβ = β(1) .- β(2)
    Δγ = γ(1) - γ(2)
    return -(Δβ[1] .* x1 .+ Δγ) ./ Δβ[2]
end

scatter(X_apples[:, 1], X_apples[:, 2], label="apples", marker=:x, markerstrokewidth=3)
scatter!(X_peaches[:, 1], X_peaches[:, 2], label="peaches", marker=:+, markerstrokewidth=3)
scatter!([x_test[1]], [x_test[2]], label="unknown")
x1 = range(-1, stop=3, length=10)
plot!(x1, discriminant_x2(x1), color="black", label="")   # discrimination boundary
plot!(x1, discriminant_x2(x1), fillrange=-10, alpha=0.2, color=:blue, label="")
plot!(x1, discriminant_x2(x1), fillrange=10, alpha=0.2, color=:red, xlims=(-0.5, 3), ylims=(-1, 4), label="")

# Apply the course notebook stylesheet.
open("../../styles/aipstyle.html") do f
    display("text/html", read(f, String))
end