In [1]:
import cPickle as pickle
import msgpack

import numpy as np
In [3]:
# Read the msgpack-serialized vocabulary (words together with their frequencies).
# Later cells unpack each entry as word -> (id, <second value>).
with open('wmt11.head.vocab', 'rb') as vocab_file:
    vocab = msgpack.load(vocab_file)
In [4]:
# Read the pickled vector data into W.
# NOTE(review): unpickling can execute arbitrary code, so only load a
# vectors file that comes from a trusted source.
with open('wmt11.head.vectors', 'rb') as vectors_file:
    W = pickle.load(vectors_file)
In [7]:
# Reverse lookup table: integer word id -> word string.
id2word = {word_id: word for word, (word_id, _) in vocab.iteritems()}
In [10]:
# Remove context word vectors: only the first len(vocab) rows of W are kept.
# Slicing first avoids normalizing rows that are about to be discarded.
W = W[:len(vocab), :]

# Normalize word vectors to unit L2 length, all rows in one vectorized step.
# assumes W is a 2-D float ndarray (rows = vectors) -- TODO confirm
row_norms = np.sqrt((W * W).sum(axis=1))
# An all-zero row has norm 0; divide it by 1 instead so it stays zero rather
# than turning into NaN (NaN scores would sort to the front of a descending
# argsort-based ranking).
row_norms[row_norms == 0] = 1
W = W / row_norms[:, np.newaxis]
In [11]:
def most_similar(positive, negative=(), topn=10, freq_threshold=5):
    """Return the words whose vectors best match a combination of query words.

    The query vector is the mean of the vectors of the `positive` words and
    the negated vectors of the `negative` words, rescaled to unit length.
    Every row of the global matrix `W` is scored by its dot product with the
    query vector; when the rows of `W` are unit-normalized this score is the
    cosine similarity.

    Uses the module-level globals `W` (one vector per row), `vocab`
    (word -> (id, frequency) pairs; the second element is taken to be the
    word frequency, as the vocabulary-loading comment states) and `id2word`
    (id -> word).

    Parameters
    ----------
    positive : sequence of words that contribute positively to the query.
    negative : sequence of words that contribute negatively (default: none).
    topn : maximum number of results to return.
    freq_threshold : candidates whose frequency is below this are skipped.

    Returns
    -------
    list of (word, score) tuples, best score first. The query words
    themselves are never returned. The list is shorter than `topn` only if
    fewer candidates pass the filters.

    Raises
    ------
    KeyError if a query word is not in `vocab`.
    ValueError if both `positive` and `negative` are empty.
    """
    positive = list(positive)
    negative = list(negative)
    if not positive and not negative:
        raise ValueError("most_similar() needs at least one positive or negative word")

    # Build the "mean" vector for the given positive and negative terms.
    signed_vecs = [W[vocab[word][0]] for word in positive]
    signed_vecs.extend(-1 * W[vocab[word][0]] for word in negative)
    mean = np.array(signed_vecs).mean(axis=0)

    # Rescale to unit length; a zero vector (e.g. the same word given as both
    # positive and negative) is left as is instead of producing NaN scores.
    mean_norm = np.linalg.norm(mean)
    if mean_norm > 0:
        mean = mean / mean_norm

    # Similarity of the query vector to every word vector.
    sims = np.dot(W, mean)

    query_words = set(positive) | set(negative)
    result = []
    # Walk candidates from best to worst and stop once `topn` have passed the
    # filters, rather than guessing a fixed-size candidate window up front.
    for idx in np.argsort(sims)[::-1]:
        if len(result) >= topn:
            break
        word = id2word[idx]
        if word in query_words:
            continue
        # Compare the frequency field, not the whole (id, frequency) tuple.
        if vocab[word][1] < freq_threshold:
            continue
        result.append((word, sims[idx]))
    return result
In [14]:
# Word-analogy query: king - man + woman, showing the 50 best matches.
most_similar(positive=['king', 'woman'], negative=['man'], topn=50)
Out[14]:
[('queen', 0.69478105985944205),
 ('ace', 0.63016505472136219),
 ('trick', 0.62198680411172658),
 ('library', 0.61596180822198343),
 ('diamond', 0.61546379436428578),
 ('club', 0.60882108049620698),
 ('horse', 0.60577931043391597),
 ('ski', 0.5980682567370863),
 ('tennis', 0.59252997663757134),
 ('chef', 0.58578732345724127),
 ('museum', 0.58238877554666368),
 ('grandmother', 0.58148552464037506),
 ('diamonds', 0.58011253208849856),
 ('crown', 0.57997983286899146),
 ('seller', 0.57651369635738636),
 ('tip', 0.57446540473288965),
 ('oldest', 0.56849683935598727),
 ('holder', 0.56698181344304011),
 ('row', 0.56597681090513596),
 ('Museum', 0.56365025428845172),
 ('royal', 0.56291276425071346),
 ('Royal', 0.56191759370337424),
 ('farmer', 0.55962264699238262),
 ('Queen', 0.55947426321308247),
 ('colony', 0.55792198467607856),
 ('Maine', 0.55782081129452066),
 ('hat', 0.55772124209691432),
 ('dog', 0.5566658071093793),
 ('Valley', 0.55537812887550264),
 ('soccer', 0.55403076872031942),
 ('cinema', 0.55362014730217401),
 ('Latvia', 0.55191882205612497),
 ('hero', 0.55175383201232631),
 ('dancer', 0.55130889459560439),
 ('spade', 0.55042516938080366),
 ('Country', 0.54924256358808599),
 ('Yale', 0.54889249198494516),
 ('Rock', 0.54845215690150428),
 ('girlfriend', 0.54695823638084806),
 ('pool', 0.54691472405799435),
 ('neighbor', 0.54683901446532124),
 ('bars', 0.54670398577022916),
 ('bottle', 0.54548588559461253),
 ('pope', 0.54363205675334925),
 ('boyfriend', 0.54230260301709221),
 ('classic', 0.54154692864168852),
 ('interior', 0.54121481559609896),
 ('Buffalo', 0.54109694727311264),
 ('buyer', 0.5408952311440024),
 ('sheriff', 0.54027326795728747)]
In [15]:
# Word-analogy query: brought - bring + seek, showing the 50 best matches.
most_similar(positive=['brought', 'seek'], negative=['bring'], topn=50)
Out[15]:
[('sought', 0.80168320406931981),
 ('seeking', 0.73662888334926047),
 ('forced', 0.69273739435205384),
 ('attempted', 0.68510971171255386),
 ('tried', 0.67516210714164604),
 ('allowed', 0.65480577594618783),
 ('urged', 0.64988576500767947),
 ('managed', 0.64642237086872134),
 ('seeks', 0.64400589953863585),
 ('refused', 0.63527139994723147),
 ('intended', 0.63348152287487691),
 ('unable', 0.62647201702998501),
 ('demanded', 0.62625912225250269),
 ('prompted', 0.62515185408955964),
 ('threatened', 0.62393983356386451),
 ('determined', 0.62224734077632959),
 ('attempting', 0.6181202137691465),
 ('hoped', 0.61761471480150942),
 ('prepared', 0.61306593357078376),
 ('encouraged', 0.61228972301898099),
 ('requested', 0.60900154224998204),
 ('followed', 0.60838821784825281),
 ('helped', 0.60657759212041662),
 ('attempt', 0.60619642784126782),
 ('failed', 0.6045095128678335),
 ('led', 0.60300298435099864),
 ('opted', 0.59973601096877438),
 ('granted', 0.59786114263781998),
 ('initiated', 0.59435889304397749),
 ('chosen', 0.59148716010685476),
 ('faced', 0.58759291122220392),
 ('wanted', 0.58484842439765283),
 ('refusing', 0.58473430252602809),
 ('addressed', 0.58405417697747952),
 ('offered', 0.58351159606145453),
 ('asking', 0.58349246210416927),
 ('rejected', 0.58123114186229263),
 ('decided', 0.58080346748275247),
 ('pledged', 0.57872006649827579),
 ('pressed', 0.57805915462397062),
 ('ordered', 0.57777715792227946),
 ('received', 0.57658184971738702),
 ('designed', 0.57650776001629578),
 ('persuaded', 0.57468989479783739),
 ('urging', 0.57295698093987468),
 ('accepted', 0.57247322897335085),
 ('allowing', 0.57049684140481094),
 ('able', 0.56909602033660478),
 ('calling', 0.56550493868341545),
 ('required', 0.565493188089782)]
In [ ]: