In [1]:
import math
import numpy as np
from pprint import pprint

In [2]:
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1],
[0.1, 0.2, 0.3, 0.4, 0.5],
[0.5, 0.4, 0.3, 0.2, 0.1]]
data = np.array(data)

In [3]:
def greedy_decoder(data):
return [np.argmax(s) for s in data]

In [4]:
# beam search
def beam_search_decoder(data, k):
sequences = [[list(), 1.0]]
for row in data:
all_candidates = list()
for i in range(len(sequences)):
seq, score = sequences[i]
for j in range(len(row)):
candidate = [seq + [j], score * -math.log(row[j]+1e-100)]
all_candidates.append(candidate)
#         pprint(("all candidates: ", all_candidates))
ordered = sorted(all_candidates, key=lambda tup:tup[1])
#         pprint(("ordered: ", ordered))
print("="*50)
sequences = ordered[:k]
pprint(("sequence: ", sequences))
return sequences

In [5]:
greedy_decoder(data)

Out[5]:
[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
In [6]:
beam_search_decoder(data, 3)

==================================================
('sequence: ',
[[[4], 0.6931471805599453],
[[3], 0.916290731874155],
[[2], 1.2039728043259361]])
==================================================
('sequence: ',
[[[4, 0], 0.4804530139182014],
[[4, 1], 0.6351243373717793],
[[3, 0], 0.6351243373717793]])
==================================================
('sequence: ',
[[[4, 0, 4], 0.33302465198892944],
[[4, 0, 3], 0.4402346437542523],
[[4, 1, 4], 0.4402346437542523]])
==================================================
('sequence: ',
[[[4, 0, 4, 0], 0.23083509858308343],
[[4, 0, 3, 0], 0.3051474021030719],
[[4, 1, 4, 0], 0.3051474021030719]])
==================================================
('sequence: ',
[[[4, 0, 4, 0, 4], 0.1600026977571413],
[[4, 0, 3, 0, 4], 0.21151206142293622],
[[4, 1, 4, 0, 4], 0.21151206142293622]])
==================================================
('sequence: ',
[[[4, 0, 4, 0, 4, 0], 0.11090541883234757],
[[4, 0, 4, 0, 4, 1], 0.1466089890297302],
[[4, 0, 3, 0, 4, 0], 0.1466089890297302]])
==================================================
('sequence: ',
[[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158],
[[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145],
[[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]])
==================================================
('sequence: ',
[[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184],
[[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441],
[[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]])
==================================================
('sequence: ',
[[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901],
[[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468],
[[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]])
==================================================
('sequence: ',
[[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]])

Out[6]:
[[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
[[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
[[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]
In [ ]: