In [1]:
#!/usr/bin/env python
import import_ipynb
In [2]:
import random
import numpy as np
from utils.treebank import StanfordSentiment
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import time

from q3_word2vec import *
from q3_sgd import *

# Reset the random seed to make sure that everyone gets the same results
random.seed(314)
dataset = StanfordSentiment()
tokens = dataset.tokens()
nWords = len(tokens)

# We are going to train 10-dimensional vectors for this assignment
dimVectors = 10

# Context size
C = 5

# Reset the random seed to make sure that everyone gets the same results
random.seed(31415)
np.random.seed(9265)

startTime = time.time()
wordVectors = np.concatenate(
    ((np.random.rand(nWords, dimVectors) - 0.5) / dimVectors,
     np.zeros((nWords, dimVectors))),
    axis=0)
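# The stacked parameter matrix has shape (2*nWords, dimVectors): rows [0, nWords)
# are initialized to small uniform values in (-0.5/dimVectors, 0.5/dimVectors) and
# rows [nWords, 2*nWords) to zero. By this assignment's convention (an assumption
# about word2vec_sgd_wrapper), the first block holds the input (center-word)
# vectors and the second block the output (context-word) vectors.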

wordVectors = sgd(
    lambda vec: word2vec_sgd_wrapper(skipgram, tokens, vec, dataset, C,
        negSamplingCostAndGradient),
    wordVectors, 0.3, 40000, None, True, PRINT_EVERY=100)
# Note that normalization is not called here. This is not a bug;
# normalizing during training would lose the notion of vector length.
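# Assuming the usual q3_sgd signature (f, x0, step, iterations, postprocessing,
# useSaved, PRINT_EVERY), the call above uses step size 0.3, 40000 iterations,
# no postprocessing (None, i.e. no normalization), and useSaved=True, which
# periodically checkpoints the parameters and produces the "saved!" lines seen
# in the output below.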

print("sanity check: cost at convergence should be around or below 10")
print("training took %d seconds" % (time.time() - startTime))

# concatenate the input and output word vectors
# Here U and V are stacked together; SVD will be applied later for the PCA visualization.
wordVectors = np.concatenate(
    (wordVectors[:nWords,:], wordVectors[nWords:,:]),
    axis=0)
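# Stacking along axis=0 simply reassembles the same (2*nWords, dimVectors) matrix,
# so the input and output vectors remain separate rows; since the token indices are
# all < nWords, the visualization below effectively uses only the input vectors.
# A common alternative, not used here, is to sum the two halves instead:
# wordVectors = wordVectors[:nWords, :] + wordVectors[nWords:, :]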

visualizeWords = [
    "the", "a", "an", ",", ".", "?", "!", "``", "''", "--",
    "good", "great", "cool", "brilliant", "wonderful", "well", "amazing",
    "worth", "sweet", "enjoyable", "boring", "bad", "waste", "dumb",
    "annoying"]
    
visualizeIdx = [tokens[word] for word in visualizeWords]
visualizeVecs = wordVectors[visualizeIdx, :]

# PCA implemented via SVD. A crucial step in PCA is centering the data so each dimension has zero mean.
temp = (visualizeVecs - np.mean(visualizeVecs, axis=0))
covariance = 1.0 / len(visualizeIdx) * temp.T.dot(temp)
# Since the covariance matrix is symmetric, the left singular vectors from its SVD are its eigenvectors, which are exactly the principal components used by PCA.
U,S,V = np.linalg.svd(covariance)
coord = temp.dot(U[:,0:2])
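# Equivalent projection (sketched as a comment, assuming an economy SVD of the
# centered data): because the covariance matrix is symmetric, its SVD coincides
# with its eigendecomposition, so projecting onto the top-2 right singular vectors
# of temp gives the same coordinates up to column signs.
# _, _, Vt = np.linalg.svd(temp, full_matrices=False)
# coord_alt = temp.dot(Vt[:2].T)   # should match coord up to sign flips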

for i in range(len(visualizeWords)):
    plt.text(coord[i,0], coord[i,1], visualizeWords[i],
        bbox=dict(facecolor='green', alpha=0.1))

plt.xlim((np.min(coord[:,0]), np.max(coord[:,0])))
plt.ylim((np.min(coord[:,1]), np.max(coord[:,1])))

plt.savefig('q3_word_vectors.png')
importing Jupyter notebook from q3_word2vec.ipynb
importing Jupyter notebook from q1_softmax.ipynb
importing Jupyter notebook from q2_gradcheck.ipynb
importing Jupyter notebook from q2_sigmoid.ipynb
importing Jupyter notebook from q3_sgd.ipynb
iter_ 100: 18.604024
iter_ 200: 18.779389
iter_ 300: 18.968864
iter_ 400: 19.209857
iter_ 500: 19.156670
iter_ 600: 19.243421
iter_ 700: 19.440181
iter_ 800: 19.520359
iter_ 900: 19.588859
iter_ 1000: 19.882518
iter_ 1100: 20.161624
iter_ 1200: 20.121685
iter_ 1300: 20.228471
iter_ 1400: 20.474703
iter_ 1500: 20.487733
iter_ 1600: 20.621488
iter_ 1700: 20.671889
iter_ 1800: 20.673450
iter_ 1900: 20.947727
iter_ 2000: 21.009787
iter_ 2100: 21.031895
iter_ 2200: 21.075417
iter_ 2300: 21.037911
iter_ 2400: 21.063337
iter_ 2500: 21.159010
iter_ 2600: 21.220136
iter_ 2700: 21.306704
iter_ 2800: 21.317022
iter_ 2900: 21.313567
iter_ 3000: 21.280537
iter_ 3100: 21.394261
iter_ 3200: 21.222326
iter_ 3300: 21.103933
iter_ 3400: 21.026450
iter_ 3500: 20.940565
iter_ 3600: 20.778982
iter_ 3700: 20.867870
iter_ 3800: 20.790209
iter_ 3900: 20.781813
iter_ 4000: 20.573691
iter_ 4100: 20.434695
iter_ 4200: 20.454322
iter_ 4300: 20.281818
iter_ 4400: 20.213335
iter_ 4500: 20.006197
iter_ 4600: 19.938834
iter_ 4700: 19.816390
iter_ 4800: 19.534211
iter_ 4900: 19.492875
iter_ 5000: 19.327795
saved!
iter_ 5100: 19.069202
iter_ 5200: 18.893172
iter_ 5300: 18.771077
iter_ 5400: 18.903853
iter_ 5500: 18.943519
iter_ 5600: 18.835504
iter_ 5700: 18.641866
iter_ 5800: 18.568050
iter_ 5900: 18.406251
iter_ 6000: 18.303576
iter_ 6100: 18.161562
iter_ 6200: 18.011232
iter_ 6300: 17.931368
iter_ 6400: 17.969555
iter_ 6500: 17.838386
iter_ 6600: 17.639902
iter_ 6700: 17.505160
iter_ 6800: 17.334591
iter_ 6900: 17.245499
iter_ 7000: 17.072789
iter_ 7100: 17.000454
iter_ 7200: 16.926262
iter_ 7300: 16.776886
iter_ 7400: 16.724920
iter_ 7500: 16.555643
iter_ 7600: 16.497449
iter_ 7700: 16.367106
iter_ 7800: 16.263621
iter_ 7900: 16.167857
iter_ 8000: 16.075852
iter_ 8100: 15.863419
iter_ 8200: 15.704432
iter_ 8300: 15.516366
iter_ 8400: 15.415402
iter_ 8500: 15.320215
iter_ 8600: 15.224080
iter_ 8700: 15.137174
iter_ 8800: 14.932750
iter_ 8900: 14.888707
iter_ 9000: 14.842695
iter_ 9100: 14.664533
iter_ 9200: 14.700342
iter_ 9300: 14.598106
iter_ 9400: 14.470129
iter_ 9500: 14.346605
iter_ 9600: 14.271724
iter_ 9700: 14.161178
iter_ 9800: 14.101970
iter_ 9900: 14.041289
iter_ 10000: 13.905749
saved!
iter_ 10100: 13.781994
iter_ 10200: 13.761556
iter_ 10300: 13.604726
iter_ 10400: 13.544350
iter_ 10500: 13.491717
iter_ 10600: 13.401767
iter_ 10700: 13.305987
iter_ 10800: 13.281872
iter_ 10900: 13.251243
iter_ 11000: 13.149489
iter_ 11100: 13.071243
iter_ 11200: 12.996503
iter_ 11300: 13.014213
iter_ 11400: 12.944917
iter_ 11500: 12.917141
iter_ 11600: 12.835549
iter_ 11700: 12.826709
iter_ 11800: 12.747148
iter_ 11900: 12.641782
iter_ 12000: 12.665167
iter_ 12100: 12.646362
iter_ 12200: 12.637761
iter_ 12300: 12.515719
iter_ 12400: 12.576099
iter_ 12500: 12.521075
iter_ 12600: 12.418966
iter_ 12700: 12.381063
iter_ 12800: 12.381649
iter_ 12900: 12.322723
iter_ 13000: 12.291331
iter_ 13100: 12.231207
iter_ 13200: 12.221519
iter_ 13300: 12.115059
iter_ 13400: 12.065030
iter_ 13500: 12.064713
iter_ 13600: 12.021912
iter_ 13700: 11.971597
iter_ 13800: 11.848290
iter_ 13900: 11.786099
iter_ 14000: 11.744749
iter_ 14100: 11.701911
iter_ 14200: 11.687953
iter_ 14300: 11.626524
iter_ 14400: 11.649125
iter_ 14500: 11.621814
iter_ 14600: 11.579251
iter_ 14700: 11.540343
iter_ 14800: 11.483450
iter_ 14900: 11.383964
iter_ 15000: 11.345420
saved!
iter_ 15100: 11.189060
iter_ 15200: 11.210092
iter_ 15300: 11.219047
iter_ 15400: 11.203139
iter_ 15500: 11.158248
iter_ 15600: 11.123628
iter_ 15700: 11.099226
iter_ 15800: 11.064870
iter_ 15900: 11.084253
iter_ 16000: 11.015952
iter_ 16100: 11.004468
iter_ 16200: 11.015049
iter_ 16300: 11.023876
iter_ 16400: 11.010123
iter_ 16500: 10.983821
iter_ 16600: 10.938413
iter_ 16700: 10.894742
iter_ 16800: 10.754350
iter_ 16900: 10.664717
iter_ 17000: 10.623251
iter_ 17100: 10.596035
iter_ 17200: 10.617019
iter_ 17300: 10.721184
iter_ 17400: 10.698315
iter_ 17500: 10.758545
iter_ 17600: 10.730561
iter_ 17700: 10.756100
iter_ 17800: 10.756223
iter_ 17900: 10.729578
iter_ 18000: 10.713750
iter_ 18100: 10.733265
iter_ 18200: 10.717193
iter_ 18300: 10.734548
iter_ 18400: 10.626955
iter_ 18500: 10.573120
iter_ 18600: 10.573155
iter_ 18700: 10.548022
iter_ 18800: 10.497992
iter_ 18900: 10.464389
iter_ 19000: 10.467851
iter_ 19100: 10.451004
iter_ 19200: 10.416570
iter_ 19300: 10.369180
iter_ 19400: 10.380386
iter_ 19500: 10.334510
iter_ 19600: 10.426575
iter_ 19700: 10.402202
iter_ 19800: 10.345839
iter_ 19900: 10.414973
iter_ 20000: 10.414057
saved!
iter_ 20100: 10.424682
iter_ 20200: 10.356544
iter_ 20300: 10.423831
iter_ 20400: 10.387770
iter_ 20500: 10.362605
iter_ 20600: 10.376197
iter_ 20700: 10.331289
iter_ 20800: 10.380293
iter_ 20900: 10.337848
iter_ 21000: 10.340521
iter_ 21100: 10.305953
iter_ 21200: 10.317894
iter_ 21300: 10.343518
iter_ 21400: 10.314110
iter_ 21500: 10.271524
iter_ 21600: 10.238044
iter_ 21700: 10.205756
iter_ 21800: 10.176787
iter_ 21900: 10.084072
iter_ 22000: 10.106953
iter_ 22100: 10.053215
iter_ 22200: 10.069332
iter_ 22300: 10.052441
iter_ 22400: 10.012923
iter_ 22500: 10.038104
iter_ 22600: 9.993996
iter_ 22700: 10.004109
iter_ 22800: 10.025915
iter_ 22900: 10.036278
iter_ 23000: 10.014603
iter_ 23100: 9.960675
iter_ 23200: 10.003456
iter_ 23300: 10.055802
iter_ 23400: 9.988262
iter_ 23500: 10.001681
iter_ 23600: 9.990759
iter_ 23700: 9.984735
iter_ 23800: 9.981500
iter_ 23900: 9.894283
iter_ 24000: 9.851653
iter_ 24100: 9.867821
iter_ 24200: 9.866621
iter_ 24300: 9.886243
iter_ 24400: 9.932614
iter_ 24500: 9.943179
iter_ 24600: 9.952576
iter_ 24700: 9.955699
iter_ 24800: 9.905475
iter_ 24900: 9.834542
iter_ 25000: 9.880161
saved!
iter_ 25100: 9.860923
iter_ 25200: 9.863371
iter_ 25300: 9.910072
iter_ 25400: 9.902151
iter_ 25500: 9.936475
iter_ 25600: 9.936556
iter_ 25700: 9.952919
iter_ 25800: 9.977249
iter_ 25900: 9.991944
iter_ 26000: 10.043130
iter_ 26100: 10.079640
iter_ 26200: 10.011304
iter_ 26300: 9.939285
iter_ 26400: 9.895720
iter_ 26500: 9.885912
iter_ 26600: 9.877268
iter_ 26700: 9.855755
iter_ 26800: 9.849249
iter_ 26900: 9.819325
iter_ 27000: 9.791931
iter_ 27100: 9.820158
iter_ 27200: 9.768645
iter_ 27300: 9.825911
iter_ 27400: 9.794431
iter_ 27500: 9.857503
iter_ 27600: 9.816648
iter_ 27700: 9.823728
iter_ 27800: 9.821667
iter_ 27900: 9.862476
iter_ 28000: 9.881795
iter_ 28100: 9.883599
iter_ 28200: 9.912319
iter_ 28300: 9.941006
iter_ 28400: 9.917078
iter_ 28500: 9.914267
iter_ 28600: 9.860607
iter_ 28700: 9.905190
iter_ 28800: 9.939964
iter_ 28900: 9.940590
iter_ 29000: 9.893362
iter_ 29100: 9.916780
iter_ 29200: 9.838282
iter_ 29300: 9.834859
iter_ 29400: 9.831548
iter_ 29500: 9.790570
iter_ 29600: 9.788810
iter_ 29700: 9.746504
iter_ 29800: 9.798515
iter_ 29900: 9.782321
iter_ 30000: 9.704538
saved!
iter_ 30100: 9.729074
iter_ 30200: 9.768284
iter_ 30300: 9.784937
iter_ 30400: 9.780768
iter_ 30500: 9.828559
iter_ 30600: 9.873279
iter_ 30700: 9.866925
iter_ 30800: 9.874105
iter_ 30900: 9.875424
iter_ 31000: 9.853438
iter_ 31100: 9.834655
iter_ 31200: 9.827070
iter_ 31300: 9.808807
iter_ 31400: 9.763422
iter_ 31500: 9.808970
iter_ 31600: 9.875959
iter_ 31700: 9.873590
iter_ 31800: 9.918868
iter_ 31900: 9.970307
iter_ 32000: 10.021178
iter_ 32100: 9.993948
iter_ 32200: 9.943728
iter_ 32300: 9.862252
iter_ 32400: 9.772822
iter_ 32500: 9.730033
iter_ 32600: 9.708618
iter_ 32700: 9.697634
iter_ 32800: 9.719430
iter_ 32900: 9.686365
iter_ 33000: 9.641960
iter_ 33100: 9.700967
iter_ 33200: 9.686835
iter_ 33300: 9.655312
iter_ 33400: 9.677657
iter_ 33500: 9.650491
iter_ 33600: 9.666483
iter_ 33700: 9.671914
iter_ 33800: 9.658939
iter_ 33900: 9.579504
iter_ 34000: 9.546839
iter_ 34100: 9.506833
iter_ 34200: 9.497960
iter_ 34300: 9.504882
iter_ 34400: 9.526258
iter_ 34500: 9.543517
iter_ 34600: 9.579199
iter_ 34700: 9.570666
iter_ 34800: 9.567154
iter_ 34900: 9.540781
iter_ 35000: 9.605205
saved!
iter_ 35100: 9.664605
iter_ 35200: 9.681135
iter_ 35300: 9.614818
iter_ 35400: 9.566828
iter_ 35500: 9.610944
iter_ 35600: 9.663755
iter_ 35700: 9.722588
iter_ 35800: 9.722484
iter_ 35900: 9.677905
iter_ 36000: 9.681488
iter_ 36100: 9.697359
iter_ 36200: 9.675975
iter_ 36300: 9.632207
iter_ 36400: 9.572983
iter_ 36500: 9.546814
iter_ 36600: 9.563374
iter_ 36700: 9.569601
iter_ 36800: 9.610571
iter_ 36900: 9.583856
iter_ 37000: 9.579905
iter_ 37100: 9.555246
iter_ 37200: 9.574555
iter_ 37300: 9.529198
iter_ 37400: 9.501708
iter_ 37500: 9.510385
iter_ 37600: 9.576810
iter_ 37700: 9.520415
iter_ 37800: 9.561922
iter_ 37900: 9.574543
iter_ 38000: 9.605944
iter_ 38100: 9.620448
iter_ 38200: 9.662221
iter_ 38300: 9.625025
iter_ 38400: 9.581447
iter_ 38500: 9.615949
iter_ 38600: 9.606902
iter_ 38700: 9.663988
iter_ 38800: 9.608830
iter_ 38900: 9.631785
iter_ 39000: 9.638900
iter_ 39100: 9.590407
iter_ 39200: 9.596133
iter_ 39300: 9.532854
iter_ 39400: 9.522097
iter_ 39500: 9.497443
iter_ 39600: 9.458520
iter_ 39700: 9.447899
iter_ 39800: 9.428893
iter_ 39900: 9.406359
iter_ 40000: 9.386926
saved!
sanity check: cost at convergence should be around or below 10
training took 9083 seconds