# Logistic regression (discriminative classification) on a two-class data set.
#
# Generate dataset {(x1,y1),...,(xN,yN)}:
#   - x is a 2-d feature vector [x_1; x_2]
#   - y ∈ {false, true} is a binary class label
#   - p(x|y) is multi-modal (mixture of uniform and Gaussian distributions)
using PyPlot
include("scripts/lesson8_helpers.jl")

N = 200
X, y = genDataset(N)    # generate data set, collect in matrix X and vector y
X_c1 = X[:, .!y]'       # feature vectors with y = false, one per row
X_c2 = X[:, y]'         # feature vectors with y = true, one per row
X_test = [3.75; 1.0]    # features of 'new' data point

"""
    plotDataSet()

Plot the training samples of both classes and the test point `X_test` on the current
PyPlot axes, then set the axis labels, legend and axis limits.
"""
function plotDataSet()
    plot(X_c1[:, 1], X_c1[:, 2], "bx", markersize=8)
    plot(X_c2[:, 1], X_c2[:, 2], "r+", markersize=8, fillstyle="none")
    plot(X_test[1], X_test[2], "ko")
    xlabel(L"x_1"); ylabel(L"x_2")
    legend([L"y=0", L"y=1", L"y=?"], loc=2)
    xlim([-2; 10]); ylim([-4, 8])
end

plotDataSet();

using Optim  # optimization library

y_1 = Float64.(y)                    # class-1 indicator vector: 1.0 where y is true, else 0.0
X_ext = vcat(X, ones(1, length(y)))  # extend X with a row of ones to allow an offset in the discrimination boundary

"""
    negative_log_likelihood(θ)

Return the negative log-likelihood -L(θ) of the logistic regression model
P(C1|x,θ) = σ(θ'x) on the extended training data `X_ext` with class-1 indicators `y_1`.

With zₙ = θ'xₙ, -log σ(z) = softplus(z) - z and -log(1 - σ(z)) = softplus(z), so
-L(θ) = Σₙ [softplus(zₙ) - yₙ zₙ]. Softplus is evaluated as max(z, 0) + log1p(exp(-|z|)),
which neither overflows nor takes log(0). The direct form log.(σ(z)) returns -Inf as soon
as σ(z) rounds to exactly 0 or 1, which can happen for well-separated data.
"""
function negative_log_likelihood(θ::AbstractVector)
    z = X_ext' * θ
    return sum(@. max(z, 0) + log1p(exp(-abs(z))) - y_1 * z)
end

# Use Optim.jl's L-BFGS optimiser to minimize the negative log-likelihood w.r.t. θ
results = optimize(negative_log_likelihood, zeros(3), LBFGS())
Optim.converged(results) ||
    @warn "Minimization of the negative log-likelihood did not converge" results
θ = Optim.minimizer(results)

# Plot the data set and the ML discrimination boundary in a new figure
figure()
plotDataSet()
p_1(x) = 1.0 / (1.0 + exp(-([x; 1.0]' * θ)))   # P(C1|x,θ) for a single feature vector x
# Discrimination boundary θ[1]*x_1 + θ[2]*x_2 + θ[3] = 0, solved for x_2
boundary(x1) = -(θ[1] .* x1 .+ θ[3]) ./ θ[2]
x1_range = [-2.0; 10.0]   # x_1 extent of the boundary lines (same as xlim in plotDataSet)
plot(x1_range, boundary(x1_range), "k-");

# Also fit the generative Gaussian model from lesson 7 and plot the resulting
# discrimination boundary for comparison
generative_boundary = buildGenerativeDiscriminationBoundary(X, y)
plot(x1_range, generative_boundary(x1_range), "k:");
legend([L"y=0"; L"y=1"; L"y=?"; "Discr. boundary"; "Gen. boundary"], loc=3);

x_test = X_test   # the same 'new' data point that plotDataSet() marks with "ko"
println("P(C1|x•,θ) = $(p_1(x_test))")

display("text/html", read("../../styles/aipstyle.html", String))