import os
os.environ['THEANO_FLAGS']="device=gpu2"
%matplotlib inline
from matplotlib import pylab as plt
import theano
from theano import tensor as T
import numpy as np
#create a regular grid in weight space for visualisation
wmin = -5
wmax = 5
wrange = np.linspace(wmin,wmax,300)
w = np.repeat(wrange[:,None],300,axis=1)
w = np.concatenate([[w.flatten()],[w.T.flatten()]])
prior_variance = 2
logprior = -(w**2).sum(axis=0)/2/prior_variance
plt.contourf(wrange, wrange, logprior.reshape(300,300), cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
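As a quick sanity check (added, not part of the original notebook), the grid log-prior should agree with scipy's Gaussian log-density up to the normalising constant that the expression above drops:
#sanity check (added): compare the grid log-prior with scipy's multivariate normal log-density
from scipy.stats import multivariate_normal
logprior_check = multivariate_normal(mean=np.zeros(2), cov=prior_variance*np.eye(2)).logpdf(w.T)
#the expression above omits the constant -log(2*pi*prior_variance)
assert np.allclose(logprior, logprior_check + np.log(2*np.pi*prior_variance))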
#generating a toy dataset with three manually selected observations
from scipy.stats import logistic
sigmoid = logistic.cdf
logsigmoid = logistic.logcdf
def likelihood(w,x,b=0,y=1):
    return logsigmoid(y*(np.dot(w.T,x) + b)).flatten()
x1 = np.array([[1.5],[1]])
x2 = np.array([[-1.5],[1]])
x3 = np.array([[0.5],[-1]])
y1=1
y2=1
y3=-1
llh1 = likelihood(w, x1, y=y1)
llh2 = likelihood(w, x2, y=y2)
llh3 = likelihood(w, x3, y=y3)
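A tiny check (added) that the likelihood behaves as expected: at w = 0 every datapoint is assigned probability 0.5, regardless of its label:
#quick check (added): at w = 0 each datapoint has likelihood 0.5, i.e. log-likelihood log(0.5)
w_zero = np.zeros((2, 1))
assert np.allclose(likelihood(w_zero, x1, y=y1), np.log(0.5))
assert np.allclose(likelihood(w_zero, x3, y=y3), np.log(0.5))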
#calculating unnormalised log posterior
#this is only for illustration
logpost = llh1 + llh2 + llh3 + logprior
#plotting the true (unnormalised) posterior
#the red dots show the three datapoints, the small line segments show the direction
#in which the corresponding label shifts the posterior: positive datapoints shift the
#posterior away from zero, towards the datapoint; negative datapoints shift it
#away from zero in the opposite direction.
plt.contourf(wrange,
wrange,
np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.plot(x1[0],x1[1],'.r')
plt.plot([x1[0],x1[0]*(1+0.2*y1)],[x1[1],x1[1]*(1+0.2*y1)],'r-')
plt.plot(x2[0],x2[1],'.r')
plt.plot([x2[0],x2[0]*(1+0.2*y2)],[x2[1],x2[1]*(1+0.2*y2)],'r-')
plt.plot(x3[0],x3[1],'.r')
plt.plot([x3[0],x3[0]*(1+0.2*y3)],[x3[1],x3[1]*(1+0.2*y3)],'r-');
This part is the actual GAN-style machinery. Here we define the generator and the denoiser network (which plays the role of the discriminator) in Lasagne, and code up the two loss functions in Theano.
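A rough summary of what the code below implements (my paraphrase): the generator $G_\phi$ maps noise $z$ to weights $w$, defining an implicit variational distribution $q_\phi(w)$, and we maximise the evidence lower bound

$$\mathcal{L}(\phi) \;=\; \mathbb{E}_{q_\phi}\!\left[\log p(\mathcal{D}\mid w) + \log p(w)\right] \;+\; \mathbb{H}[q_\phi].$$

The entropy term has no closed form for an implicit $q_\phi$, so its gradient is estimated with a denoiser: a denoiser $D$ trained on samples corrupted with Gaussian noise of variance $\sigma^2$ satisfies, near its optimum and for small $\sigma^2$, $\big(D(w) - w\big)/\sigma^2 \approx \nabla_w \log q_\phi(w)$, so $(w - D(w))/\sigma^2$ serves as an estimate of $-\nabla_w \log q_\phi(w)$, which is exactly what enters the entropy gradient below.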
from lasagne.utils import floatX
from lasagne.layers import (
InputLayer,
DenseLayer,
NonlinearityLayer)
from lasagne.nonlinearities import sigmoid
#defines the 'generator' network
def build_G(input_var=None, num_z=3):
    network = InputLayer(input_var=input_var, shape=(None, num_z))
    network = DenseLayer(incoming=network, num_units=10)
    network = DenseLayer(incoming=network, num_units=20)
    network = DenseLayer(incoming=network, num_units=2, nonlinearity=None)
    return network
#defines the 'denoiser' network
def build_denoiser(input_var=None):
    network = InputLayer(input_var=input_var, shape=(None, 2))
    network = DenseLayer(incoming=network, num_units=20)
    network = DenseLayer(incoming=network, num_units=10)
    network = DenseLayer(incoming=network, num_units=20)
    network = DenseLayer(incoming=network, num_units=2, nonlinearity=None)
    return network
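Optionally (added, not in the original), one can check the size of the two networks with Lasagne's parameter-counting helper:
#optional (added): number of trainable parameters in each network
import lasagne
print('generator parameters:', lasagne.layers.count_params(build_G(), trainable=True))
print('denoiser parameters:', lasagne.layers.count_params(build_denoiser(), trainable=True))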
from lasagne.layers import get_output, get_all_params
from theano.printing import debugprint
from lasagne.updates import adam
from theano.tensor.shared_randomstreams import RandomStreams
from lasagne.objectives import squared_error
#variables for input (design matrix), output labels, GAN noise variable, weights
x_var = T.matrix('design matrix')
y_var = T.vector('labels')
z_var = T.matrix('GAN noise')
w_var = T.matrix('weights')
#theano variables for things like batchsize, learning rate, etc.
batchsize_var = T.scalar('batchsize', dtype='int32')
prior_variance_var = T.scalar('prior variance')
noise_variance_var = T.scalar('noise variance')
learningrate_var = T.scalar('learning rate')
#random numbers for sampling from the variational distribution
srng = RandomStreams(seed=1337)
z_rnd = srng.normal((batchsize_var,3))
epsilon_rnd = T.sqrt(noise_variance_var)*srng.normal((batchsize_var,2))
#instantiating the G and denoiser networks
generator = build_G(z_var)
denoiser = build_denoiser()
#these expressions are random samples from the variational distribution and noisy versions of them, respectively
samples_from_generator = get_output(generator, z_rnd)
noisy_samples_from_generator = samples_from_generator + epsilon_rnd
#denoiser output for synthetic samples and noisy synthetic samples
denoised_noisy_samples_from_generator = get_output(denoiser, inputs=noisy_samples_from_generator)
denoised_samples_from_generator = get_output(denoiser, inputs=samples_from_generator)
#loss of the denoiser: mean squared error between the clean samples and the denoiser's reconstruction of their noisy versions
loss_denoiser = squared_error(samples_from_generator, denoised_noisy_samples_from_generator).mean()
#log likelihood for each synthetic w sampled from the generator
log_likelihood = T.log(
T.nnet.sigmoid(
(y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * samples_from_generator.dimshuffle('x', 1, 0))).sum(1)
)
).sum(0).mean()
#log prior for synthetic w sampled from the generator
log_prior = -((samples_from_generator**2).sum(1)/2/prior_variance_var).mean()
params_G = get_all_params(generator, trainable=True)
#calculating the derivative of the entropy with respect to parameters of G, using theano's Lop
dHdG = (samples_from_generator - denoised_samples_from_generator)/noise_variance_var
dHdPhi = T.Lop(
f = samples_from_generator.flatten()/batchsize_var,
wrt = params_G,
eval_points=dHdG.flatten())
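#(added note) concretely, the Lop call computes the Jacobian-transpose-times-vector product
#   dH/dphi ~= (1/N) * sum_i (dw_i/dphi)^T (w_i - D(w_i)) / sigma^2,   with w_i = G_phi(z_i),
#T.Lop(f, wrt, eval_points) returns (df/dwrt)^T * eval_points without ever forming the Jacobian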
#calculating gradients of other terms in the bound and summing it all up
dLikelihooddPhi = T.grad(log_likelihood, wrt=params_G)
dPriordPhi = T.grad(log_prior,wrt=params_G)
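#adam minimises, so we feed it the negative of the bound's gradient (entropy + prior + likelihood)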
dLdPhi = [-a-b-c for a,b,c in zip(dHdPhi,dPriordPhi,dLikelihooddPhi)]
updates_G = adam(
dLdPhi,
params_G,
learning_rate=learningrate_var,
)
#compiling theano functions:
evaluate_generator = theano.function(
[z_var],
get_output(generator),
allow_input_downcast=True
)
sample_generator = theano.function(
[batchsize_var],
samples_from_generator,
allow_input_downcast=True,
)
params_denoiser = get_all_params(denoiser, trainable=True)
updates_denoiser = adam(
loss_denoiser,
params_denoiser,
learning_rate = learningrate_var
)
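Note that Lasagne's update functions take either a scalar loss or a precomputed list of gradients as their first argument (`loss_or_grads`), which is why adam receives the gradient list dLdPhi for the generator above but the scalar loss_denoiser here.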
train_denoiser = theano.function(
[batchsize_var, noise_variance_var, learningrate_var],
loss_denoiser,
updates = updates_denoiser,
allow_input_downcast = True
)
train_G = theano.function(
[x_var, y_var, prior_variance_var, noise_variance_var, batchsize_var, learningrate_var],
[],
updates = updates_G,
allow_input_downcast = True
)
#note: `batchsize` here is the plain Python constant set further down (batchsize = 100), so that cell must be run first
grad_comp_1 = theano.clone(dHdG, replace={samples_from_generator: w_var})/batchsize
grad_comp_2 = T.grad(theano.clone(log_prior, replace={samples_from_generator: w_var}), wrt=w_var)
grad_comp_3 = T.grad(theano.clone(log_likelihood, replace={samples_from_generator: w_var}), wrt=w_var)
evaluate_gradients_G = theano.function(
[w_var, x_var, y_var, prior_variance_var, noise_variance_var],
[grad_comp_1 + grad_comp_2 + grad_comp_3, grad_comp_1, grad_comp_2, grad_comp_3],
allow_input_downcast = True
)
evaluate_denoiser = theano.function(
[w_var],
get_output(denoiser, w_var),
allow_input_downcast = True
)
#this is to evaluate the per-datapoint likelihood of an arbitrary set of w (the log is taken outside theano)
llh_for_w = T.nnet.sigmoid((y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * w_var.dimshuffle('x', 1, 0))).sum(1))
evaluate_loglikelihood = theano.function(
[x_var, y_var, w_var],
llh_for_w,
allow_input_downcast = True
)
#checking that theano and numpy give the same likelihoods
import seaborn as sns
sns.set_context('poster')
X_train = np.concatenate([x1,x2,x3],axis=1).T
y_train = np.array([y1,y2,y3])
llh_theano = evaluate_loglikelihood(X_train, y_train, w.T)
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.contourf(wrange, wrange ,np.log(llh_theano).sum(0).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('theano loglikelihood')
plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('numpy loglikelihood')
assert np.allclose(llh1+llh2+llh3,np.log(llh_theano).sum(0))
#q and the denoiser before training:
wrange_spaced = np.linspace(wmin,wmax,30)
w_spaced = np.repeat(wrange_spaced[:,None],30,axis=1)
w_spaced = np.concatenate([[w_spaced.flatten()],[w_spaced.T.flatten()]])
arrows = evaluate_denoiser(w_spaced.T) - w_spaced.T
plt.quiver(w_spaced[0,:],w_spaced[1,:],arrows[:,0],arrows[:,1])
W = sample_generator(200)
plt.plot(W[:,0],W[:,1],'g.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
batchsize = 100
KARPATHY_CONSTANT = 3e-4
learning_rate = KARPATHY_CONSTANT*10
prior_variance = 1
#pre-training on large noise to avoid spurious modes
noise_variance = 0.7
# train the denoiser for a while before starting the alternating updates
for i in range(5):
    %timeit -n 200 train_denoiser(batchsize, noise_variance, learning_rate)
    print(train_denoiser(batchsize*10, noise_variance, learning_rate))
noise_variance = 0.1
# continue training the denoiser at the smaller noise level
for i in range(10):
    %timeit -n 200 train_denoiser(batchsize, noise_variance, learning_rate)
    print(train_denoiser(batchsize*10, noise_variance, learning_rate))
[output: each %timeit round takes ~1.4 ms per loop; the printed denoiser loss is roughly 0.08-0.10 at noise_variance = 0.7 and drops to roughly 0.03 after switching to noise_variance = 0.1]
#q and denoiser after training the denoiser:
arrows = evaluate_denoiser(w_spaced.T) - w_spaced.T
plt.quiver(w_spaced[0,:],w_spaced[1,:],arrows[:,0],arrows[:,1])
W = sample_generator(1000)
plt.plot(W[:,0],W[:,1],'g.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
# this is the main training loop. For each gradient step of G I do a few hundred gradient steps of the denoiser (%timeit -n 300 with its default repeats)
print (train_denoiser(batchsize, noise_variance, 0))
for i in range(200):
    %timeit -n 300 train_denoiser(batchsize, noise_variance, learning_rate)
    print (train_denoiser(batchsize, noise_variance, 0))
    train_G(X_train, y_train, prior_variance, noise_variance, batchsize, learning_rate)
[output: 200 outer iterations; each %timeit round reports ~1.4 ms per loop, and the printed denoiser loss fluctuates between roughly 0.03 and 0.10, drifting upwards as the generator is trained]
#gradient components of the bound evaluated on a regular grid, after training
sns.set_style('whitegrid')
plt.subplots(figsize=(16, 16))
arrows_sum, arrows_1, arrows_2, arrows_3 = evaluate_gradients_G(w_spaced.T, X_train, y_train, prior_variance, noise_variance)
W = sample_generator(200)
plt.subplot(2,2,1)
plt.quiver(w_spaced[0,:],w_spaced[1,:],0.1*arrows_sum[:,0],0.1*arrows_sum[:,1])
plt.plot(W[:,0],W[:,1],'g.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([])
plt.title('overall gradients')
plt.subplot(2,2,2)
plt.quiver(w_spaced[0,:],w_spaced[1,:],arrows_1[:,0],arrows_1[:,1])
plt.plot(W[:,0],W[:,1],'g.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([])
plt.title('entropy(from denoiser)')
plt.subplot(2,2,3)
plt.quiver(w_spaced[0,:],w_spaced[1,:],arrows_2[:,0],arrows_2[:,1])
plt.plot(W[:,0],W[:,1],'g.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([]);
plt.title('prior');
plt.subplot(2,2,4)
plt.quiver(w_spaced[0,:],w_spaced[1,:],arrows_3[:,0],arrows_3[:,1])
plt.plot(W[:,0],W[:,1],'g.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([]);
plt.title('likelihood');
W = sample_generator(100)
arrows_sum, arrows_1, arrows_2, arrows_3 = evaluate_gradients_G(W, X_train, y_train, prior_variance, noise_variance)
sns.set_style('whitegrid')
plt.subplots(figsize=(16,16))
arrow_scale = 3e-1
plt.subplot(2,2,1)
plt.plot(W[:,0],W[:,1],'g.',alpha=0.3)
plt.quiver(W[:,0], W[:,1], arrows_sum[:,0], arrows_sum[:,1], scale = arrow_scale, scale_units='width', width=0.005)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([])
plt.title('overall gradients')
plt.subplot(2,2,2)
plt.plot(W[:,0],W[:,1],'g.',alpha=0.3)
plt.quiver(W[:,0], W[:,1], arrows_1[:,0], arrows_1[:,1], scale=arrow_scale, scale_units='width', width=0.005)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([])
plt.title('entropy (from denoiser)')
plt.subplot(2,2,3)
plt.plot(W[:,0],W[:,1],'g.',alpha=0.3)
plt.quiver(W[:,0], W[:,1], arrows_2[:,0], arrows_2[:,1], scale = arrow_scale, scale_units='width', width=0.005)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([]);
plt.title('prior');
plt.subplot(2,2,4)
plt.plot(W[:,0],W[:,1],'g.',alpha=0.3)
plt.quiver(W[:,0], W[:,1], arrows_3[:,0], arrows_3[:,1], scale = arrow_scale, scale_units='width', width=0.005)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.xticks([])
plt.yticks([]);
plt.title('likelihood');
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(1000)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true posterior');
sns.set_style('whitegrid')
plt.subplot(1,2,2)
W = sample_generator(5000)
plot = sns.kdeplot(W[:,0],W[:,1])
plt.axis('square')
plot.set(xlim=(wmin,wmax))
plot.set(ylim=(wmin,wmax))
plt.title('kde of approximate posterior')
plt.subplot(1,2,1)
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true posterior');
#looking at average squared gradient magnitudes
#at convergence the entropy gradient should roughly cancel the prior and likelihood
#gradients combined, so its average magnitude should be comparable to theirs
W = sample_generator(100)
arrows_sum, arrows_1, arrows_2, arrows_3 = evaluate_gradients_G(W, X_train, y_train, prior_variance, noise_variance)
print('entropy:',(arrows_1**2).mean())
print('log-prior:',(arrows_2**2).mean())
print('log-likelihood:',(arrows_3**2).mean())
entropy: 0.000157828
log-prior: 0.000114875
log-likelihood: 6.52165e-05