#!/usr/bin/env python
# Train 10-d word2vec vectors (skip-gram + negative sampling) on the
# Stanford Sentiment Treebank and visualize selected words in 2D via PCA.
import import_ipynb  # lets the q*.ipynb notebooks below be imported as modules
import random
import numpy as np
from utils.treebank import StanfordSentiment
import matplotlib
# Select the non-interactive 'agg' backend so plt.savefig works headlessly;
# this must run before pyplot is imported.
matplotlib.use('agg')
import matplotlib.pyplot as plt
import time
from q3_word2vec import *
from q3_sgd import *
# Reset the random seed to make sure that everyone gets the same results
random.seed(314)
dataset = StanfordSentiment()
tokens = dataset.tokens()  # mapping: word string -> vocabulary index
nWords = len(tokens)  # vocabulary size
# We are going to train 10-dimensional vectors for this assignment
dimVectors = 10
# Context size
C = 5
# Reset the random seed to make sure that everyone gets the same results
random.seed(31415)
np.random.seed(9265)
startTime=time.time()  # wall-clock start; elapsed time is reported after training
# Parameter matrix: the first nWords rows are the "input" (center-word)
# vectors, drawn uniformly from [-0.5, 0.5) and scaled by 1/dimVectors;
# the last nWords rows are the "output" (context) vectors, all zeros.
inputVectors = (np.random.rand(nWords, dimVectors) - 0.5) / dimVectors
outputVectors = np.zeros((nWords, dimVectors))
wordVectors = np.concatenate((inputVectors, outputVectors), axis=0)
def _skipgramObjective(vec):
    # Objective handed to sgd: cost and gradient of skip-gram with
    # negative sampling, averaged over a random batch of context windows.
    return word2vec_sgd_wrapper(skipgram, tokens, vec, dataset, C,
                                negSamplingCostAndGradient)

# Run SGD: step size 0.3, 40000 iterations, no postprocessing,
# resume from saved parameters if available, log every 100 iterations.
wordVectors = sgd(_skipgramObjective, wordVectors, 0.3, 40000, None, True,
                  PRINT_EVERY=100)
# Normalization is deliberately not applied here: normalizing during
# training would lose the notion of vector length.
elapsed = time.time() - startTime
print("sanity check: cost at convergence should be around or below 10")
print("training took %d seconds" % elapsed)
# concatenate the input and output word vectors
# Here U and V are merged; the result is decomposed with SVD below.
# NOTE(review): with axis=0 this concatenation of the top and bottom halves
# simply reproduces wordVectors unchanged (a no-op), so the visualization
# below indexes only the input (center-word) vectors in the first nWords
# rows.  Merging input and output vectors per word would need axis=1 (or a
# sum of the two halves) — confirm intent.
wordVectors = np.concatenate(
(wordVectors[:nWords,:], wordVectors[nWords:,:]),
axis=0)
# Words to project into 2D: function words / punctuation, then
# positive-sentiment words, then negative-sentiment words.
visualizeWords = [
    "the", "a", "an", ",", ".", "?", "!", "``", "''", "--",
    "good", "great", "cool", "brilliant", "wonderful", "well", "amazing",
    "worth", "sweet", "enjoyable", "boring", "bad", "waste", "dumb",
    "annoying",
]
# Row index of each chosen word in the embedding matrix.
visualizeIdx = [tokens[w] for w in visualizeWords]
visualizeVecs = wordVectors[visualizeIdx, :]
# PCA implemented via SVD.  Centering (zero mean per dimension) is
# essential for PCA, so subtract the column means first.
centered = visualizeVecs - np.mean(visualizeVecs, axis=0)
cov = 1.0 / len(visualizeIdx) * centered.T.dot(centered)
# The left singular vectors of the (symmetric) covariance matrix are its
# eigenvectors, i.e. exactly the PCA principal components.
U, _S, _V = np.linalg.svd(cov)
# Project onto the top two principal components for a 2D layout.
coord = centered.dot(U[:, 0:2])
# Place each word at its 2D PCA coordinate, fit the axes to the data,
# and save the figure to disk (agg backend: no display needed).
for i, word in enumerate(visualizeWords):
    plt.text(coord[i, 0], coord[i, 1], word,
             bbox=dict(facecolor='green', alpha=0.1))
plt.xlim((coord[:, 0].min(), coord[:, 0].max()))
plt.ylim((coord[:, 1].min(), coord[:, 1].max()))
plt.savefig('q3_word_vectors.png')
# NOTE(review): the lines originally below were captured notebook/console
# output pasted into the source file (import_ipynb "importing Jupyter
# notebook" messages and the SGD training log from iter 100 to iter 40000,
# cost falling from ~18.6 to ~9.39, ending with the sanity-check message
# and "training took 9083 seconds").  They are not Python and made the
# script raise a SyntaxError; the transcript has been removed from the
# executable source and summarized here.