include("InfoGAN.jl"); InfoGAN.define_params("--help") args = "--epochs 200 --lr 1e-4 1e-4 1e-4 1e-4 --clip 0.01 --z 62 --c_SS 10 --c_cont 2 --dreps 5 -v --datadir MNIST" o = InfoGAN.define_params(args); display(o) ~ispath("../tests") && mkpath("../tests/") LOGFILE = "../tests/mnist-log.txt" MDLFILE = "../tests/mnist-model.jld2" PRINTFOLDER = "../tests/outputs" (xtrn,xtst,ytrn,ytst) = InfoGAN.loaddata(o; drop_p = 0.8) @info "Size of Datasets: " size(xtrn) size(xtst) size(ytrn) size(ytst) using Plots xtrn .-= minimum(xtrn; dims =(1,2)) xtrn ./= maximum(xtrn; dims=(1,2)) num2show = 10 idx = rand(1:size(xtrn,4), num2show) imgs = [xtrn[:,:,1,i]' for i = idx]; imgarr = hcat(imgs...) plot(Gray.(imgarr), size = (num2show*120, 120), axis=false) ? InfoGAN.defaultG using Knet bn(h, mode) = batchnorm(h, bnmoments(); training=mode) bn_relu(h, mode) = relu.(bn(h, mode)) prelu(x, p) = max.(p .* x, x) bn_prelu(h, mode, p) = prelu(bn(h, mode), p) # Leaky ReLU parameter α = 0.1; wFE = [randn(4,4,1,64) * sqrt(2/(1*4*4)), zeros(1,1,64,1), randn(4,4,64,128) * sqrt(2/(64*4*4)), zeros(1,1,128,1), randn(1024, 5*5*128) * sqrt(2/5*5*128), zeros(1024,1)] wFE = map(wi->convert(o[:atype], wi), wFE) function fe_fun(w, x; atype=Array{Float32}, mode=true) a = prelu(conv4(w[1], x; stride=2) .+ w[2], α) b = bn_prelu(conv4(w[3], a; stride=2) .+ w[4], mode, α) c = bn_prelu((w[5] * mat(b)) .+ w[6], mode, α) return c end wD = [randn(1, 1024) * sqrt(2/1024), zeros(1,1)] wD = map(wi->convert(o[:atype], wi), wD) function d_fun(w, fc) return (w[1] * fc) .+ w[2] end wQ = [randn(128, 1024) * sqrt(2/1024), zeros(128,1), randn(10, 128) * sqrt(2/128), zeros(10,1), randn(o[:c_cont]*2, 128) * sqrt(2/128), zeros(o[:c_cont]*2, 1)] wQ = map(wi->convert(o[:atype], wi), wQ) function q_fun(w, fc, o; mode=true) fc2 = bn_relu((w[1] * fc) .+ w[2], mode) logitsSS = [(w[3] * fc2) .+ wQ[4]] logits=nothing fc_cont = (w[5] * fc2) .+ w[6] mu = fc_cont[1:o[:c_cont],:] var = exp.(fc_cont[o[:c_cont]+1:end,:]) return logitsSS, logits, mu, var end wG = [randn(1024, 74) * sqrt(2/74), zeros(1024,1), randn(7*7*128, 1024) * sqrt(2/1024), zeros(7*7*128,1), randn(4,4,64,128) * sqrt(2/(128*4*4)), zeros(1,1,64,1), randn(3,3,64,64) * sqrt(2/(64*9)), zeros(1,1,64,1), randn(4,4,1,64) * sqrt(2/(64*4*4)), zeros(1,1,1,1), randn(3,3,1,1) * sqrt(2/(1*9)), zeros(1,1,1,1)] wG = map(wi->convert(o[:atype], wi), wG) function g_fun(w, Z; mode=true) Z= mat(Z) # #Linear Layers fc1 = bn((w[1] * Z) .+ w[2], mode) fc2 = bn_relu((w[3] * fc1) .+ w[4], mode) a = reshape(fc2, 7, 7, 128, :) b = deconv4(w[5], a; stride=2, padding=1) .+ w[6] b = bn_relu(conv4(w[7], b; padding=1) .+ w[8], mode) c = deconv4(w[9], b; stride=2, padding=1) .+ w[10] c = conv4(w[11], c; padding=1) .+ w[12] return c end dump(InfoGAN.FrontEnd) Fopt = map(x->Rmsprop(;lr=o[:lr][1]), wFE) F = InfoGAN.FrontEnd(wFE, o[:atype], fe_fun, Fopt); G = InfoGAN.Generator(wG, o, g_fun) D = InfoGAN.Discriminator(wD, o, d_fun) Q = InfoGAN.Auxiliary(wQ, o, q_fun) model = InfoGAN.InfoModel(F, D, G, Q, o); @time results = InfoGAN.train(xtrn, ytrn, xtst, ytst, model; mdlfile=MDLFILE, logfile=LOGFILE, printfolder=PRINTFOLDER); plot(results, xlabel = "Epoch Number", ylabel="Loss", label=["Discriminator", "Generator"]) function get_results(epoch_init, epoch_final, interval, img_dir) epochs = epoch_init:interval:epoch_final p, imgs = [], [] for e = epochs imgfile = string(img_dir, "/epoch", lpad(e, 3, "0"), ".png") img = convert(Array{Float32}, channelview(Gray.(load(imgfile))))[1:56, 1:56] push!(imgs, img) end return imgs, epochs end 
using Images

IMGDIR = string(pwd(), "/tests/outputs");
imgs, epochs = get_results(10, 100, 10, IMGDIR);
imgsarr = hcat(imgs...);
size(imgsarr)

# Tile the saved sample grids from epochs 10, 20, ..., 100 side by side
h = 120
plot(Gray.(imgsarr), size=(length(epochs)*h, h), axis=false)

# Predicted categorical codes for the test set
preds = InfoGAN.get_c(xtst, model);

# Confusion matrix of true vs. predicted digit classes (confusmat expects labels in 1:10)
using MLBase
confusmat(10, convert(Array{Int64}, ytst[1,:]), preds[1,:])
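# (Sketch, not part of InfoGAN.jl) Overall and per-class accuracy from the confusion matrix
# above; MLBase's confusmat puts true classes in rows and predicted classes in columns.
using LinearAlgebra
C = confusmat(10, convert(Array{Int64}, ytst[1,:]), preds[1,:])
acc      = sum(diag(C)) / sum(C)                    # overall accuracy
perclass = diag(C) ./ max.(vec(sum(C; dims=2)), 1)  # recall per digit class
@info "Test-set code recovery" acc perclass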