In [1]:
%matplotlib inline
from matplotlib import pylab as plt
import theano
from theano import tensor as T
import numpy as np
Using gpu device 0: Tesla K80 (CNMeM is disabled, CuDNN 4007)
In [2]:
#create a regular grid in weight space for visualisation
wmin = -5
wmax = 5
wrange = np.linspace(wmin,wmax,300)
w = np.repeat(wrange[:,None],300,axis=1)
w = np.concatenate([[w.flatten()],[w.T.flatten()]])
In [3]:
prior_variance = 2
logprior = -(w**2).sum(axis=0)/2/prior_variance
plt.contourf(wrange, wrange, logprior.reshape(300,300), cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
In [4]:
#generating a toy dataset with three manually selected observations
from scipy.stats import logistic
sigmoid = logistic.cdf
logsigmoid = logistic.logcdf

def likelihood(w,x,b=0,y=1):
    return logsigmoid(y*(np.dot(w.T,x) + b)).flatten()


x1 = np.array([[1.5],[1]])
x2 = np.array([[-1.5],[1]])
x3 = np.array([[0.5],[-1]])

y1=1
y2=1
y3=-1

llh1 = likelihood(w, x1, y=y1) 
llh2 = likelihood(w, x2, y=y2) 
llh3 = likelihood(w, x3, y=y3) 
In [5]:
#calculating unnormalised log posterior
#this is only for illustration
logpost = llh1 + llh2 + llh3 + logprior
In [6]:
#plotting the real log posterior
#the red dots show the three datapoints, the small line segments shows the direction
#in which the corresponding label shifts the posterior. Positive datapoints shift the 
# posterior away from zero in the direction of the datapoint, negative datapoints shift
# away from zero, in the opposite direction.

plt.contourf(wrange,
             wrange,
             np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])

plt.plot(x1[0],x1[1],'.r')
plt.plot([x1[0],x1[0]*(1+0.2*y1)],[x1[1],x1[1]*(1+0.2*y1)],'r-')

plt.plot(x2[0],x2[1],'.r')
plt.plot([x2[0],x2[0]*(1+0.2*y2)],[x2[1],x2[1]*(1+0.2*y2)],'r-')

plt.plot(x3[0],x3[1],'.r')
plt.plot([x3[0],x3[0]*(1+0.2*y3)],[x3[1],x3[1]*(1+0.2*y3)],'r-');

Fitting an approximate posterior

This part is for the actual GAN stuff. Here we define the generator and the discriminator networks in Lasagne, and code up the two loss functions in theano.

In [7]:
from lasagne.utils import floatX

from lasagne.layers import (
    InputLayer,
    DenseLayer,
    NonlinearityLayer)
from lasagne.nonlinearities import sigmoid

#defines a 'generator' network
def build_G(input_var=None, num_z = 3):
    
    network = InputLayer(input_var=input_var, shape=(None, num_z))
    
    network = DenseLayer(incoming = network, num_units=10)
    
    network = DenseLayer(incoming = network, num_units=20)
    
    network = DenseLayer(incoming = network, num_units=2, nonlinearity=None)
    
    return network

#defines the 'discriminator network'
def build_D(input_var=None):

    network = InputLayer(input_var=input_var, shape = (None, 2))
    
    network = DenseLayer(incoming = network, num_units=10)
    
    network = DenseLayer(incoming = network, num_units=20)
    
    network = DenseLayer(incoming = network, num_units=1, nonlinearity=None)
    
    normalised = NonlinearityLayer(incoming = network, nonlinearity = sigmoid)
    
    return { 'unnorm':network, 'norm':normalised }
In [8]:
from lasagne.layers import get_output, get_all_params
from theano.printing import debugprint
from lasagne.updates import adam
from theano.tensor.shared_randomstreams import RandomStreams

#variables for input (design matrix), output labels, GAN noise variable, weights
x_var = T.matrix('design matrix')
y_var = T.vector('labels')
z_var = T.matrix('GAN noise')
w_var = T.matrix('weights')

#theano variables for things like batchsize, learning rate, etc.
batchsize_var = T.scalar('batchsize', dtype='int32')
prior_variance_var = T.scalar('prior variance')
learningrate_var = T.scalar('learning rate')

#random numbers for sampling from the prior or from the GAN
srng = RandomStreams(seed=1337)
z_rnd = srng.normal((batchsize_var,3))
prior_rnd = srng.normal((batchsize_var,2))

#instantiating the G and D networks
generator = build_G(z_var)
discriminator = build_D()

#these expressions are random samples from the generator and the prior, respectively
samples_from_grenerator = get_output(generator, z_rnd)
samples_from_prior = prior_rnd*T.sqrt(prior_variance_var)

#discriminator output for synthetic samples, both normalised and unnormalised (after/before sigmoid)
D_of_G = get_output(discriminator['norm'], inputs=samples_from_grenerator)
s_of_G = get_output(discriminator['unnorm'], inputs=samples_from_grenerator)

#discriminator output for real samples from the prior
D_of_prior = get_output(discriminator['norm'], inputs=samples_from_prior)

#loss of discriminator - simple binary cross-entropy loss
loss_D = -T.log(D_of_G).mean() - T.log(1-D_of_prior).mean()

#log likelihood for each synthetic w sampled from the generator
log_likelihood = T.log(
    T.nnet.sigmoid(
        (y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * samples_from_grenerator.dimshuffle('x', 1, 0))).sum(1)
    )
).sum(0).mean()

#loss for G is the sum of unnormalised discriminator output and the negative log likelihood
loss_G = s_of_G.mean() - log_likelihood

#compiling theano functions:
evaluate_generator = theano.function(
    [z_var],
    get_output(generator),
    allow_input_downcast=True
)

sample_generator = theano.function(
    [batchsize_var],
    samples_from_grenerator,
    allow_input_downcast=True,
)

sample_prior = theano.function(
    [prior_variance_var, batchsize_var],
    samples_from_prior,
    allow_input_downcast=True
)

params_D = get_all_params(discriminator['norm'], trainable=True)

updates_D = adam(
    loss_D,
    params_D,
    learning_rate = learningrate_var
)

train_D = theano.function(
    [learningrate_var, batchsize_var, prior_variance_var],
    loss_D,
    updates = updates_D,
    allow_input_downcast = True
)

params_G = get_all_params(generator, trainable=True)

updates_G = adam(
    loss_G,
    params_G,
    learning_rate = learningrate_var
)

train_G = theano.function(
    [x_var, y_var, learningrate_var, batchsize_var],
    loss_G,
    updates = updates_G,
    allow_input_downcast = True
)

evaluate_discriminator = theano.function(
    [w_var],
    get_output([discriminator['unnorm'],discriminator['norm']],w_var),
    allow_input_downcast = True
)

#this is to evaluate the log-likelihood of an arbitrary set of w
llh_for_w = T.nnet.sigmoid((y_var.dimshuffle(0,'x','x')*(x_var.dimshuffle(0,1,'x') * w_var.dimshuffle('x', 1, 0))).sum(1))

evaluate_loglikelihood = theano.function(
        [x_var, y_var, w_var],
        llh_for_w,
        allow_input_downcast = True
    )
In [9]:
#checking that theano and numpy give the same likelihoods
import seaborn as sns
sns.set_context('poster')

X_train = np.concatenate([x1,x2,x3],axis=1).T
y_train = np.array([y1,y2,y3])
llh_theano = evaluate_loglikelihood(X_train, y_train, w.T)

plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.contourf(wrange, wrange ,np.log(llh_theano).sum(0).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('theano loglikelihood')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T,cmap='gray');
plt.axis('square');
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('numpy loglikelihood')

assert np.allclose(llh1+llh2+llh3,np.log(llh_theano).sum(0))
In [14]:
batchsize = 200
KARPATHY_CONSTANT = 3e-4
learning_rate = KARPATHY_CONSTANT*10

prior_variance = 2
# train discriminator for some time before starting iterative process
%timeit -n 300 train_D(learning_rate, batchsize, prior_variance)
# print initial values of training errors
print (train_D(0, 100, prior_variance))
print (train_G(X_train, y_train, 0, batchsize))
for i in range(200):
    %timeit -n 50 train_D(learning_rate, batchsize, prior_variance)
    print (train_D(0, 100, prior_variance))
    print (train_G(X_train, y_train, 0, 100))
    train_G(X_train, y_train, learning_rate, batchsize)
300 loops, best of 3: 1.52 ms per loop
0.8448745012283325
1.9514195919036865
50 loops, best of 3: 1.51 ms per loop
0.9600536823272705
2.1756317615509033
50 loops, best of 3: 1.52 ms per loop
1.1132714748382568
2.1207809448242188
50 loops, best of 3: 1.51 ms per loop
0.9374406337738037
2.033437490463257
50 loops, best of 3: 1.51 ms per loop
0.9812651872634888
2.122642993927002
50 loops, best of 3: 1.5 ms per loop
0.9524849653244019
2.1304209232330322
50 loops, best of 3: 1.52 ms per loop
0.9748408198356628
1.9830398559570312
50 loops, best of 3: 1.52 ms per loop
0.9469156265258789
2.034944772720337
50 loops, best of 3: 1.52 ms per loop
1.0123794078826904
2.0215725898742676
50 loops, best of 3: 1.51 ms per loop
0.9374747276306152
1.8552474975585938
50 loops, best of 3: 1.52 ms per loop
0.990984320640564
1.9634348154067993
50 loops, best of 3: 1.52 ms per loop
0.8005599975585938
2.0340428352355957
50 loops, best of 3: 1.5 ms per loop
1.0441194772720337
2.078402042388916
50 loops, best of 3: 1.51 ms per loop
0.9515339732170105
2.0935752391815186
50 loops, best of 3: 1.51 ms per loop
1.014158010482788
1.9193150997161865
50 loops, best of 3: 1.5 ms per loop
0.9064014554023743
1.9159090518951416
50 loops, best of 3: 1.5 ms per loop
0.8740888833999634
1.8587417602539062
50 loops, best of 3: 1.5 ms per loop
1.045501947402954
2.0620100498199463
50 loops, best of 3: 1.49 ms per loop
1.102818489074707
1.9564135074615479
50 loops, best of 3: 1.5 ms per loop
0.9090922474861145
2.018059015274048
50 loops, best of 3: 1.5 ms per loop
0.9764355421066284
1.9801956415176392
50 loops, best of 3: 1.5 ms per loop
0.9260282516479492
1.9836033582687378
50 loops, best of 3: 1.5 ms per loop
1.0356860160827637
2.0456244945526123
50 loops, best of 3: 1.51 ms per loop
0.941687822341919
1.9629082679748535
50 loops, best of 3: 1.5 ms per loop
0.9731559753417969
1.988584041595459
50 loops, best of 3: 1.5 ms per loop
1.0802245140075684
2.0056726932525635
50 loops, best of 3: 1.5 ms per loop
1.0316039323806763
1.9496046304702759
50 loops, best of 3: 1.5 ms per loop
0.9684480428695679
1.869762897491455
50 loops, best of 3: 1.5 ms per loop
1.0378801822662354
1.902108073234558
50 loops, best of 3: 1.49 ms per loop
1.0136998891830444
1.9007588624954224
50 loops, best of 3: 1.51 ms per loop
1.0407465696334839
1.9785401821136475
50 loops, best of 3: 1.51 ms per loop
1.032365083694458
1.8645954132080078
50 loops, best of 3: 1.5 ms per loop
0.9352719187736511
1.9393141269683838
50 loops, best of 3: 1.5 ms per loop
1.0411553382873535
1.9759348630905151
50 loops, best of 3: 1.51 ms per loop
0.8414285778999329
2.059307098388672
50 loops, best of 3: 1.51 ms per loop
0.980594277381897
1.9248079061508179
50 loops, best of 3: 1.51 ms per loop
0.9445054531097412
1.9292949438095093
50 loops, best of 3: 1.51 ms per loop
1.045969843864441
1.9180612564086914
50 loops, best of 3: 1.5 ms per loop
1.0355181694030762
1.9287148714065552
50 loops, best of 3: 1.5 ms per loop
1.0907528400421143
2.0063419342041016
50 loops, best of 3: 1.52 ms per loop
1.030317783355713
2.0014772415161133
50 loops, best of 3: 1.51 ms per loop
0.962618887424469
1.8605748414993286
50 loops, best of 3: 1.51 ms per loop
0.9774050116539001
1.910995602607727
50 loops, best of 3: 1.51 ms per loop
0.9333393573760986
1.9094206094741821
50 loops, best of 3: 1.52 ms per loop
0.9691608548164368
1.8858442306518555
50 loops, best of 3: 1.52 ms per loop
0.8662000894546509
1.9423495531082153
50 loops, best of 3: 1.52 ms per loop
0.8949613571166992
1.9160100221633911
50 loops, best of 3: 1.52 ms per loop
0.9095511436462402
1.924254298210144
50 loops, best of 3: 1.53 ms per loop
1.0445687770843506
1.9673411846160889
50 loops, best of 3: 1.51 ms per loop
0.978882908821106
1.8847665786743164
50 loops, best of 3: 1.51 ms per loop
0.9121913909912109
1.89730966091156
50 loops, best of 3: 1.5 ms per loop
0.9401355385780334
1.8624334335327148
50 loops, best of 3: 1.5 ms per loop
0.9630769491195679
1.992774248123169
50 loops, best of 3: 1.51 ms per loop
0.9487576484680176
1.9473307132720947
50 loops, best of 3: 1.51 ms per loop
1.0263659954071045
1.89525306224823
50 loops, best of 3: 1.51 ms per loop
0.9195380806922913
1.9877896308898926
50 loops, best of 3: 1.52 ms per loop
0.9764193296432495
1.9331001043319702
50 loops, best of 3: 1.52 ms per loop
0.9087851047515869
1.8820929527282715
50 loops, best of 3: 1.52 ms per loop
0.9978067278862
1.9613386392593384
50 loops, best of 3: 1.52 ms per loop
0.985075056552887
1.9285860061645508
50 loops, best of 3: 1.52 ms per loop
0.8952118158340454
1.9643129110336304
50 loops, best of 3: 1.53 ms per loop
0.9691531658172607
1.9074223041534424
50 loops, best of 3: 1.52 ms per loop
1.018507480621338
1.9238728284835815
50 loops, best of 3: 1.52 ms per loop
1.0125757455825806
1.9139434099197388
50 loops, best of 3: 1.52 ms per loop
0.9592885971069336
1.9897146224975586
50 loops, best of 3: 1.54 ms per loop
1.001117467880249
1.9631487131118774
50 loops, best of 3: 1.5 ms per loop
1.0618159770965576
1.865829586982727
50 loops, best of 3: 1.51 ms per loop
1.0544630289077759
1.8829889297485352
50 loops, best of 3: 1.5 ms per loop
0.9327152967453003
2.012500047683716
50 loops, best of 3: 1.5 ms per loop
0.9645974636077881
1.8529531955718994
50 loops, best of 3: 1.51 ms per loop
1.0016453266143799
1.9349000453948975
50 loops, best of 3: 1.51 ms per loop
0.875522255897522
1.8374173641204834
50 loops, best of 3: 1.51 ms per loop
0.9100744724273682
1.9483352899551392
50 loops, best of 3: 1.52 ms per loop
0.8262972235679626
1.9207834005355835
50 loops, best of 3: 1.51 ms per loop
1.0230774879455566
1.8619059324264526
50 loops, best of 3: 1.51 ms per loop
0.8782156705856323
1.8719462156295776
50 loops, best of 3: 1.5 ms per loop
1.0510547161102295
1.9405272006988525
50 loops, best of 3: 1.51 ms per loop
1.0219957828521729
1.952405571937561
50 loops, best of 3: 1.51 ms per loop
1.1573123931884766
1.8239548206329346
50 loops, best of 3: 1.51 ms per loop
1.064415454864502
1.9159923791885376
50 loops, best of 3: 1.51 ms per loop
1.0200016498565674
1.8292263746261597
50 loops, best of 3: 1.5 ms per loop
0.9069321155548096
1.9733787775039673
50 loops, best of 3: 1.52 ms per loop
0.9966238141059875
1.9710220098495483
50 loops, best of 3: 1.51 ms per loop
0.9025123119354248
1.8472779989242554
50 loops, best of 3: 1.51 ms per loop
1.0376310348510742
1.8177932500839233
50 loops, best of 3: 1.51 ms per loop
1.011881947517395
1.9010998010635376
50 loops, best of 3: 1.52 ms per loop
0.9227458238601685
1.9468706846237183
50 loops, best of 3: 1.52 ms per loop
0.862939715385437
1.9307807683944702
50 loops, best of 3: 1.52 ms per loop
0.9486161470413208
1.8571929931640625
50 loops, best of 3: 1.51 ms per loop
0.8988595008850098
1.8314225673675537
50 loops, best of 3: 1.51 ms per loop
0.9509600400924683
1.8611066341400146
50 loops, best of 3: 1.51 ms per loop
0.9949578046798706
1.878565788269043
50 loops, best of 3: 1.51 ms per loop
0.8730878829956055
1.9759849309921265
50 loops, best of 3: 1.51 ms per loop
0.9754554629325867
1.9739391803741455
50 loops, best of 3: 1.53 ms per loop
0.863519549369812
1.8903473615646362
50 loops, best of 3: 1.52 ms per loop
1.0119316577911377
1.8875197172164917
50 loops, best of 3: 1.52 ms per loop
1.0236506462097168
1.9025930166244507
50 loops, best of 3: 1.53 ms per loop
0.9017845392227173
1.8345367908477783
50 loops, best of 3: 1.54 ms per loop
1.0171787738800049
1.871687889099121
50 loops, best of 3: 1.53 ms per loop
0.932150661945343
1.8545056581497192
50 loops, best of 3: 1.54 ms per loop
0.8963805437088013
1.8622294664382935
50 loops, best of 3: 1.52 ms per loop
0.9883332848548889
1.9165235757827759
50 loops, best of 3: 1.52 ms per loop
1.0065093040466309
1.9213820695877075
50 loops, best of 3: 1.53 ms per loop
0.9831069707870483
1.8940064907073975
50 loops, best of 3: 1.53 ms per loop
1.0976672172546387
1.859683632850647
50 loops, best of 3: 1.52 ms per loop
1.0611366033554077
1.838294506072998
50 loops, best of 3: 1.52 ms per loop
1.0289955139160156
1.8952610492706299
50 loops, best of 3: 1.53 ms per loop
0.8878443837165833
1.9366623163223267
50 loops, best of 3: 1.51 ms per loop
0.9999676942825317
1.9656920433044434
50 loops, best of 3: 1.51 ms per loop
1.0331406593322754
1.8805748224258423
50 loops, best of 3: 1.51 ms per loop
0.8540570735931396
1.9150948524475098
50 loops, best of 3: 1.52 ms per loop
1.053358793258667
1.962748646736145
50 loops, best of 3: 1.52 ms per loop
1.0905137062072754
1.9155274629592896
50 loops, best of 3: 1.51 ms per loop
0.9880650043487549
1.9293056726455688
50 loops, best of 3: 1.51 ms per loop
1.026888370513916
1.8844863176345825
50 loops, best of 3: 1.51 ms per loop
0.9923015832901001
1.9336037635803223
50 loops, best of 3: 1.51 ms per loop
0.8971617221832275
1.8725190162658691
50 loops, best of 3: 1.5 ms per loop
0.920857310295105
1.8894339799880981
50 loops, best of 3: 1.5 ms per loop
1.0053694248199463
1.9092943668365479
50 loops, best of 3: 1.5 ms per loop
0.9656652212142944
1.9232662916183472
50 loops, best of 3: 1.53 ms per loop
0.9977065324783325
1.8844738006591797
50 loops, best of 3: 1.51 ms per loop
1.2036333084106445
1.898746132850647
50 loops, best of 3: 1.51 ms per loop
1.0021357536315918
1.940970778465271
50 loops, best of 3: 1.51 ms per loop
0.9431014060974121
1.9282381534576416
50 loops, best of 3: 1.51 ms per loop
0.9904188513755798
1.904990553855896
50 loops, best of 3: 1.51 ms per loop
0.9765397310256958
1.8933037519454956
50 loops, best of 3: 1.5 ms per loop
0.9893286228179932
1.951127052307129
50 loops, best of 3: 1.5 ms per loop
0.8894766569137573
1.9410812854766846
50 loops, best of 3: 1.51 ms per loop
0.8829526305198669
1.8478164672851562
50 loops, best of 3: 1.5 ms per loop
0.9760034084320068
1.8913582563400269
50 loops, best of 3: 1.5 ms per loop
0.8971577882766724
1.9102680683135986
50 loops, best of 3: 1.54 ms per loop
0.9775246381759644
1.977861762046814
50 loops, best of 3: 1.51 ms per loop
0.91898113489151
1.900395154953003
50 loops, best of 3: 1.5 ms per loop
1.0857927799224854
1.9142166376113892
50 loops, best of 3: 1.49 ms per loop
0.893352746963501
1.937159538269043
50 loops, best of 3: 1.5 ms per loop
0.9891102313995361
1.8830279111862183
50 loops, best of 3: 1.5 ms per loop
0.9243213534355164
1.9274885654449463
50 loops, best of 3: 1.51 ms per loop
1.0036383867263794
1.9389046430587769
50 loops, best of 3: 1.51 ms per loop
0.7969959378242493
1.9497418403625488
50 loops, best of 3: 1.51 ms per loop
1.0847009420394897
1.9572519063949585
50 loops, best of 3: 1.51 ms per loop
0.864370584487915
1.9119330644607544
50 loops, best of 3: 1.51 ms per loop
0.9453173875808716
1.96038019657135
50 loops, best of 3: 1.5 ms per loop
0.8675397634506226
1.8303234577178955
50 loops, best of 3: 1.5 ms per loop
0.9847473502159119
1.8589293956756592
50 loops, best of 3: 1.5 ms per loop
0.9194648265838623
1.915841817855835
50 loops, best of 3: 1.5 ms per loop
0.9463154077529907
1.844448447227478
50 loops, best of 3: 1.51 ms per loop
1.0360722541809082
1.858137845993042
50 loops, best of 3: 1.5 ms per loop
1.04253351688385
1.968713641166687
50 loops, best of 3: 1.49 ms per loop
1.0069658756256104
1.8724918365478516
50 loops, best of 3: 1.5 ms per loop
1.0952907800674438
1.9573665857315063
50 loops, best of 3: 1.51 ms per loop
0.802614688873291
1.8787537813186646
50 loops, best of 3: 1.5 ms per loop
0.8913273811340332
1.8159719705581665
50 loops, best of 3: 1.5 ms per loop
0.9341927766799927
1.867208480834961
50 loops, best of 3: 1.51 ms per loop
1.0589802265167236
1.9539941549301147
50 loops, best of 3: 1.5 ms per loop
0.9097493886947632
1.906331181526184
50 loops, best of 3: 1.5 ms per loop
0.9541696906089783
1.909193992614746
50 loops, best of 3: 1.5 ms per loop
0.9327772855758667
1.9232544898986816
50 loops, best of 3: 1.49 ms per loop
1.146897792816162
1.8370968103408813
50 loops, best of 3: 1.51 ms per loop
1.0738775730133057
1.8403311967849731
50 loops, best of 3: 1.49 ms per loop
0.9754031300544739
1.9976423978805542
50 loops, best of 3: 1.51 ms per loop
0.8779315948486328
1.9406893253326416
50 loops, best of 3: 1.52 ms per loop
1.0117578506469727
1.8763118982315063
50 loops, best of 3: 1.52 ms per loop
0.9227365255355835
1.8432515859603882
50 loops, best of 3: 1.54 ms per loop
0.9738056659698486
1.8978081941604614
50 loops, best of 3: 1.51 ms per loop
0.9701780080795288
1.9601927995681763
50 loops, best of 3: 1.52 ms per loop
1.0007644891738892
1.883573055267334
50 loops, best of 3: 1.52 ms per loop
0.9890692234039307
1.9037498235702515
50 loops, best of 3: 1.52 ms per loop
0.9662845134735107
1.9076638221740723
50 loops, best of 3: 1.51 ms per loop
0.9601151943206787
1.8417010307312012
50 loops, best of 3: 1.51 ms per loop
0.9368103742599487
1.8808245658874512
50 loops, best of 3: 1.53 ms per loop
0.9241345524787903
1.9551916122436523
50 loops, best of 3: 1.51 ms per loop
0.9748488664627075
1.962646722793579
50 loops, best of 3: 1.5 ms per loop
0.8990902900695801
1.9472965002059937
50 loops, best of 3: 1.51 ms per loop
0.9632441997528076
1.8204729557037354
50 loops, best of 3: 1.53 ms per loop
0.9495596885681152
1.950182318687439
50 loops, best of 3: 1.51 ms per loop
0.9088999032974243
1.8404701948165894
50 loops, best of 3: 1.51 ms per loop
0.8672885298728943
1.9343440532684326
50 loops, best of 3: 1.53 ms per loop
1.020755648612976
1.8917264938354492
50 loops, best of 3: 1.51 ms per loop
0.8814740180969238
1.8805466890335083
50 loops, best of 3: 1.51 ms per loop
0.9514379501342773
1.8690009117126465
50 loops, best of 3: 1.51 ms per loop
1.017188549041748
1.8670145273208618
50 loops, best of 3: 1.51 ms per loop
0.9998301267623901
1.9476951360702515
50 loops, best of 3: 1.51 ms per loop
0.8669408559799194
1.9175958633422852
50 loops, best of 3: 1.5 ms per loop
0.8989511132240295
1.9025661945343018
50 loops, best of 3: 1.5 ms per loop
0.9505672454833984
1.8321565389633179
50 loops, best of 3: 1.51 ms per loop
1.0252314805984497
1.8379268646240234
50 loops, best of 3: 1.5 ms per loop
0.8472341299057007
1.875765085220337
50 loops, best of 3: 1.51 ms per loop
0.9293055534362793
1.8557850122451782
50 loops, best of 3: 1.5 ms per loop
0.9614561796188354
1.8672025203704834
50 loops, best of 3: 1.5 ms per loop
0.9699405431747437
1.9759284257888794
50 loops, best of 3: 1.51 ms per loop
0.933645486831665
1.8617061376571655
50 loops, best of 3: 1.52 ms per loop
0.886096715927124
1.9568455219268799
50 loops, best of 3: 1.49 ms per loop
0.9915707111358643
1.9972354173660278
50 loops, best of 3: 1.5 ms per loop
0.9266557097434998
1.928452491760254
50 loops, best of 3: 1.51 ms per loop
0.8907764554023743
1.8821607828140259
50 loops, best of 3: 1.51 ms per loop
1.0371427536010742
1.883976697921753
50 loops, best of 3: 1.5 ms per loop
0.9632725715637207
1.9309386014938354
50 loops, best of 3: 1.5 ms per loop
1.053124189376831
1.883291244506836
50 loops, best of 3: 1.49 ms per loop
0.9354572296142578
1.922497034072876
50 loops, best of 3: 1.5 ms per loop
0.8532063961029053
1.9385792016983032
50 loops, best of 3: 1.49 ms per loop
0.9147511720657349
1.8892595767974854
In [15]:
plt.figure(figsize=(16,8))


plt.subplot(1,2,1)
plt.contourf(wrange, wrange, evaluate_discriminator(w.T)[0].reshape(300,300).T)

W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'g.')

W = sample_prior(prior_variance, 100)
plt.plot(W[:,0],W[:,1],'r.')
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax])
plt.title('estimated log density ratio $\Phi^{-1}(D)$')

plt.subplot(1,2,2)
plt.contourf(wrange, wrange, (llh1+llh2+llh3).reshape(300,300).T)
plt.axis('square')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);

plt.title('log likelihood');
In [16]:
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(1000)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true log posterior');
In [17]:
sns.set_style('whitegrid')
plt.subplot(1,2,2)

W = sample_generator(5000)
plot = sns.kdeplot(W[:,0],W[:,1])
plt.axis('square')
plot.set(xlim=(wmin,wmax))
plot.set(ylim=(wmin,wmax))
plt.title('kde of approximate posterior')

plt.subplot(1,2,1)
plt.contourf(wrange, wrange, np.exp(logpost.reshape(300,300).T),cmap='gray');
plt.axis('square');
W = sample_generator(100)
plt.plot(W[:,0],W[:,1],'.g')
plt.xlim([wmin,wmax])
plt.ylim([wmin,wmax]);
plt.title('true posterior');
In [ ]: