In [1]:
import math
import numpy as np
from pprint import pprint
In [2]:
data = [[0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1],
        [0.1, 0.2, 0.3, 0.4, 0.5],
        [0.5, 0.4, 0.3, 0.2, 0.1]]
data = np.array(data)
In [3]:
def greedy_decoder(data):
    return [np.argmax(s) for s in data]
In [4]:
# beam search
def beam_search_decoder(data, k):
    sequences = [[list(), 1.0]]
    for row in data:
        all_candidates = list()
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score * -math.log(row[j]+1e-100)]
                all_candidates.append(candidate)
#         pprint(("all candidates: ", all_candidates))
        ordered = sorted(all_candidates, key=lambda tup:tup[1])
#         pprint(("ordered: ", ordered))
        print("="*50)
        sequences = ordered[:k]
        pprint(("sequence: ", sequences))
    return sequences
In [5]:
greedy_decoder(data)
Out[5]:
[4, 0, 4, 0, 4, 0, 4, 0, 4, 0]
In [6]:
beam_search_decoder(data, 3)
==================================================
('sequence: ',
 [[[4], 0.6931471805599453],
  [[3], 0.916290731874155],
  [[2], 1.2039728043259361]])
==================================================
('sequence: ',
 [[[4, 0], 0.4804530139182014],
  [[4, 1], 0.6351243373717793],
  [[3, 0], 0.6351243373717793]])
==================================================
('sequence: ',
 [[[4, 0, 4], 0.33302465198892944],
  [[4, 0, 3], 0.4402346437542523],
  [[4, 1, 4], 0.4402346437542523]])
==================================================
('sequence: ',
 [[[4, 0, 4, 0], 0.23083509858308343],
  [[4, 0, 3, 0], 0.3051474021030719],
  [[4, 1, 4, 0], 0.3051474021030719]])
==================================================
('sequence: ',
 [[[4, 0, 4, 0, 4], 0.1600026977571413],
  [[4, 0, 3, 0, 4], 0.21151206142293622],
  [[4, 1, 4, 0, 4], 0.21151206142293622]])
==================================================
('sequence: ',
 [[[4, 0, 4, 0, 4, 0], 0.11090541883234757],
  [[4, 0, 4, 0, 4, 1], 0.1466089890297302],
  [[4, 0, 3, 0, 4, 0], 0.1466089890297302]])
==================================================
('sequence: ',
 [[[4, 0, 4, 0, 4, 0, 4], 0.07687377837246158],
  [[4, 0, 4, 0, 4, 0, 3], 0.10162160739070145],
  [[4, 0, 4, 0, 4, 1, 4], 0.10162160739070145]])
==================================================
('sequence: ',
 [[[4, 0, 4, 0, 4, 0, 4, 0], 0.05328484273786184],
  [[4, 0, 4, 0, 4, 0, 4, 1], 0.07043873064683441],
  [[4, 0, 4, 0, 4, 0, 3, 0], 0.07043873064683441]])
==================================================
('sequence: ',
 [[[4, 0, 4, 0, 4, 0, 4, 0, 4], 0.03693423851032901],
  [[4, 0, 4, 0, 4, 0, 4, 0, 3], 0.04882440755007468],
  [[4, 0, 4, 0, 4, 0, 4, 1, 4], 0.04882440755007468]])
==================================================
('sequence: ',
 [[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
  [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
  [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]])
Out[6]:
[[[4, 0, 4, 0, 4, 0, 4, 0, 4, 0], 0.025600863289563108],
 [[4, 0, 4, 0, 4, 0, 4, 0, 4, 1], 0.03384250043584397],
 [[4, 0, 4, 0, 4, 0, 4, 0, 3, 0], 0.03384250043584397]]
In [ ]: