In [4]:
import numpy as np

In [1]:
# Unigram counts C(w) for the seven words of the example.
# Insertion order (I, want, to, eat, Chinese, food, lunch) is the row order
# used by every later cell that iterates over this dict.
wc = dict(
    I=3437,
    want=1215,
    to=3256,
    eat=938,
    Chinese=213,
    food=1506,
    lunch=459,
)

In [111]:
# Raw (unsmoothed) bigram counts C(w_{n-1} w_n).
# Row i is paired with the i-th word of `wc` by the probability cells below.
# Columns presumably follow the same word order -- TODO confirm against the
# source table.
rawbigramc = [
    #  I   want    to   eat  Chinese food  lunch
    [  8,  1087,    0,   12,    0,     0,    0],  # I
    [  3,     0,  786,    0,    6,     8,    6],  # want
    [  3,     0,   10,  860,    3,     0,   12],  # to
    [  0,     0,    2,    0,   19,     2,   52],  # eat
    [  2,     0,    0,    0,    0,   120,    1],  # Chinese
    [ 19,     0,   17,    0,    0,     0,    0],  # food
    [  4,     0,    0,    0,    0,     1,    0],  # lunch
]

In [117]:
# Despite the name `tokenc`, this number is used as the vocabulary size V
# (number of word types), not as a token count. It is the "+ V" in the add-one
# denominator C(w_{n-1}) + V and the V in Witten-Bell's Z = V - T.
VOCAB_SIZE = 1616
tokenc = VOCAB_SIZE  # original name, kept because later cells refer to it

In [92]:
# Per-history-word values used as T(w) in the Witten-Bell cell, keyed as
# "<w> x" (x = any following word). In Witten-Bell, T(w) is the number of
# distinct word types observed after w.
wtype_bic = {
    history + " x": n_types
    for history, n_types in (
        ("I", 95),
        ("want", 76),
        ("to", 130),
        ("eat", 124),
        ("Chinese", 20),
        ("food", 82),
        ("lunch", 45),
    )
}


In [114]:
# Add-one (Laplace) smoothing step: every bigram count, seen or unseen, is
# incremented by one. A new nested list is built; `rawbigramc` is untouched.
bigramc = [
    [count + 1 for count in row]
    for row in rawbigramc
]

In [120]:
# Add-one (Laplace) smoothed bigram probabilities:
#     P(w_n | w_{n-1}) = (C(w_{n-1} w_n) + 1) / (C(w_{n-1}) + V)
# `bigramc` already holds C + 1, `wc[w]` is C(w_{n-1}), and `tokenc` acts as V.
# Rows are taken in `wc` order -- the same order the display cell iterates --
# so that row i of `prob` always matches the word it is printed for.
res = []
for i, w in enumerate(wc):
    item = [x / (wc[w] + tokenc) for x in bigramc[i]]
    res.append(item)
prob = np.array(res)
# Replaces the hard-coded reshape(7, 7): the table must be square over the vocabulary in `wc`.
assert prob.shape == (len(wc), len(wc)), f"expected a {len(wc)}x{len(wc)} table, got {prob.shape}"

In [121]:
# Print reconstituted counts c* = P(w_n | w_{n-1}) * C(w_{n-1}), one row per
# history word in `wc` order, rounded to 3 decimals.
# np.round(...).tolist() rounds exactly as before but yields Python floats.
# Printing a list of np.float64 would show `np.float64(6.122)` under NumPy >= 2,
# so the output would no longer match the saved results.
for i, w in enumerate(wc):
    print(np.round(prob[i] * wc[w], 3).tolist())

[6.122, 740.047, 0.68, 8.842, 0.68, 0.68, 0.68]
[1.717, 0.429, 337.762, 0.429, 3.004, 3.863, 3.004]
[2.673, 0.668, 7.351, 575.414, 2.673, 0.668, 8.688]
[0.367, 0.367, 1.102, 0.367, 7.345, 1.102, 19.465]
[0.349, 0.116, 0.116, 0.116, 0.116, 14.091, 0.233]
[9.648, 0.482, 8.683, 0.482, 0.482, 0.482, 0.482]
[1.106, 0.221, 0.221, 0.221, 0.221, 0.442, 0.221]


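## Witten-Bell discounting

For a history word $w_{i-1}$ with $N = C(w_{i-1})$ and $T$ distinct continuation types:
- Each seen bigram gets $C(w_{i-1}w_i)/(N+T)$.
- Each of the $Z = V - T$ unseen continuations gets $T/\big(Z(N+T)\big)$.

The table printed below shows reconstituted counts $c^* = P(w_i \mid w_{i-1}) \cdot C(w_{i-1})$.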

In [118]:
# Witten-Bell discounted bigram probabilities. For history word w_{i-1}:
#     N = C(w_{i-1})   unigram count from `wc` (stands in for the number of
#                      bigram tokens that start with w_{i-1})
#     T = T(w_{i-1})   from `wtype_bic`, keyed "<w> x"
#     Z = V - T        V = `tokenc`; number of continuations never seen
#     P(w_i | w_{i-1}) = C(w_{i-1} w_i) / (N + T)    if C(w_{i-1} w_i) > 0
#                      = T / (Z * (N + T))           otherwise
# Uses the raw (not add-one) counts. Rows follow `wc` order, the same order
# the display cell iterates.
res = []
for i, w in enumerate(wc):
    n = wc[w]
    t = wtype_bic[w + " x"]
    z = tokenc - t
    # The unseen-bigram mass T/(N+T) is split evenly over Z words.
    # Z <= 0 would divide by zero or produce negative probabilities.
    assert z > 0, f"V ({tokenc}) must exceed T({w!r}) = {t}"
    item = [x / (n + t) if x > 0 else t / ((n + t) * z) for x in rawbigramc[i]]
    res.append(item)
prob = np.array(res)
# Replaces the hard-coded reshape(7, 7): the table must be square over the vocabulary in `wc`.
assert prob.shape == (len(wc), len(wc)), f"expected a {len(wc)}x{len(wc)} table, got {prob.shape}"

In [119]:
# Print reconstituted counts c* = P(w_i | w_{i-1}) * C(w_{i-1}), one row per
# history word in `wc` order, rounded to 3 decimals.
# np.round(...).tolist() rounds exactly as before but yields Python floats.
# Printing a list of np.float64 would show `np.float64(7.785)` under NumPy >= 2,
# so the output would no longer match the saved results.
for i, w in enumerate(wc):
    print(np.round(prob[i] * wc[w], 3).tolist())

[7.785, 1057.763, 0.061, 11.677, 0.061, 0.061, 0.061]
[2.823, 0.046, 739.729, 0.046, 5.647, 7.529, 5.647]
[2.885, 0.084, 9.616, 826.982, 2.885, 0.084, 11.539]
[0.073, 0.073, 1.766, 0.073, 16.782, 1.766, 45.928]
[1.828, 0.011, 0.011, 0.011, 0.011, 109.7, 0.914]
[18.019, 0.051, 16.122, 0.051, 0.051, 0.051, 0.051]
[3.643, 0.026, 0.026, 0.026, 0.026, 0.911, 0.026]

In [ ]:


In [ ]: