In [1]:
import cPickle as pickle
import msgpack

import numpy as np

In [3]:
# Load vocabulary w/ word frequencies
with open('wmt11.head.vocab', 'rb') as f:

In [4]:
# Load requisite vector data
with open('wmt11.head.vectors', 'rb') as f:

In [7]:
id2word = dict((id, word) for word, (id, _) in vocab.iteritems())

In [10]:
# Normalize word vectors
for i, row in enumerate(W):
W[i, :] /= np.linalg.norm(row)

# Remove context word vectors
W = W[:len(vocab), :]

In [11]:
def most_similar(positive, negative, topn=10, freq_threshold=5):
# Build a "mean" vector for the given positive and negative terms
mean_vecs = []
for word in positive: mean_vecs.append(W[vocab[word][0]])
for word in negative: mean_vecs.append(-1 * W[vocab[word][0]])

mean = np.array(mean_vecs).mean(axis=0)
mean /= np.linalg.norm(mean)

# Now calculate cosine distances between this mean vector and all others
dists = np.dot(W, mean)

best = np.argsort(dists)[::-1][:topn + len(positive) + len(negative) + 100]
result = [(id2word[i], dists[i]) for i in best if (vocab[id2word[i]] >= freq_threshold
and id2word[i] not in positive
and id2word[i] not in negative)]
return result[:topn]

In [14]:
most_similar(['king', 'woman'], ['man'], topn=50)

Out[14]:
[('queen', 0.69478105985944205),
('ace', 0.63016505472136219),
('trick', 0.62198680411172658),
('library', 0.61596180822198343),
('diamond', 0.61546379436428578),
('club', 0.60882108049620698),
('horse', 0.60577931043391597),
('ski', 0.5980682567370863),
('tennis', 0.59252997663757134),
('chef', 0.58578732345724127),
('museum', 0.58238877554666368),
('grandmother', 0.58148552464037506),
('diamonds', 0.58011253208849856),
('crown', 0.57997983286899146),
('seller', 0.57651369635738636),
('tip', 0.57446540473288965),
('oldest', 0.56849683935598727),
('holder', 0.56698181344304011),
('row', 0.56597681090513596),
('Museum', 0.56365025428845172),
('royal', 0.56291276425071346),
('Royal', 0.56191759370337424),
('farmer', 0.55962264699238262),
('Queen', 0.55947426321308247),
('colony', 0.55792198467607856),
('Maine', 0.55782081129452066),
('hat', 0.55772124209691432),
('dog', 0.5566658071093793),
('Valley', 0.55537812887550264),
('soccer', 0.55403076872031942),
('cinema', 0.55362014730217401),
('Latvia', 0.55191882205612497),
('hero', 0.55175383201232631),
('dancer', 0.55130889459560439),
('Country', 0.54924256358808599),
('Yale', 0.54889249198494516),
('Rock', 0.54845215690150428),
('girlfriend', 0.54695823638084806),
('pool', 0.54691472405799435),
('neighbor', 0.54683901446532124),
('bars', 0.54670398577022916),
('bottle', 0.54548588559461253),
('pope', 0.54363205675334925),
('boyfriend', 0.54230260301709221),
('classic', 0.54154692864168852),
('interior', 0.54121481559609896),
('Buffalo', 0.54109694727311264),
('sheriff', 0.54027326795728747)]
In [15]:
most_similar(['brought', 'seek'], ['bring'], topn=50)

Out[15]:
[('sought', 0.80168320406931981),
('seeking', 0.73662888334926047),
('forced', 0.69273739435205384),
('attempted', 0.68510971171255386),
('tried', 0.67516210714164604),
('allowed', 0.65480577594618783),
('urged', 0.64988576500767947),
('managed', 0.64642237086872134),
('seeks', 0.64400589953863585),
('refused', 0.63527139994723147),
('intended', 0.63348152287487691),
('unable', 0.62647201702998501),
('demanded', 0.62625912225250269),
('prompted', 0.62515185408955964),
('threatened', 0.62393983356386451),
('determined', 0.62224734077632959),
('attempting', 0.6181202137691465),
('hoped', 0.61761471480150942),
('prepared', 0.61306593357078376),
('encouraged', 0.61228972301898099),
('requested', 0.60900154224998204),
('followed', 0.60838821784825281),
('helped', 0.60657759212041662),
('attempt', 0.60619642784126782),
('failed', 0.6045095128678335),
('led', 0.60300298435099864),
('opted', 0.59973601096877438),
('granted', 0.59786114263781998),
('initiated', 0.59435889304397749),
('chosen', 0.59148716010685476),
('faced', 0.58759291122220392),
('wanted', 0.58484842439765283),
('refusing', 0.58473430252602809),
('offered', 0.58351159606145453),
('rejected', 0.58123114186229263),
('decided', 0.58080346748275247),
('pledged', 0.57872006649827579),
('pressed', 0.57805915462397062),
('ordered', 0.57777715792227946),
('required', 0.565493188089782)]