In [1]:
import os
import collections
import operator
import torch
import numpy as np
In [2]:
# Directory holding the downloaded checkpoint, plus the trained-model id.
# os.path.expanduser("~") also resolves when $HOME is unset (e.g. Windows),
# unlike the original os.environ["HOME"] lookup, which raises KeyError there.
MODEL_ROOT = os.path.join(os.path.expanduser("~"), "Downloads")
MODEL_ID = "f140225004"
In [3]:
def grab_mod(sd, path):
    """Extract the sub-state-dict whose keys start with ``path``.

    Args:
        sd: state-dict-like mapping of dotted parameter names to values.
        path: key prefix selecting one submodule (e.g. "decoder.mlp.0.").

    Returns:
        dict mapping each matching key, with the prefix stripped, to its value.
    """
    prefix_len = len(path)
    selected = {}
    for key, value in sd.items():
        if key.startswith(path):
            selected[key[prefix_len:]] = value
    return selected
In [4]:
class ByteLSTM(torch.nn.Module):
    """Byte-level text classifier: byte embedding -> bi-LSTM -> mean pool -> MLP -> softmax."""

    def __init__(
        self,
        embedding_width,
        lstm_width,
        lstm_depth,
        mlp_widths,
    ):
        """
        Args:
            embedding_width: dimension of each per-byte embedding vector.
            lstm_width: hidden size of each LSTM direction.
            lstm_depth: number of stacked LSTM layers.
            mlp_widths: output widths of the classifier MLP layers; the last
                entry is the number of classes.
        """
        super(ByteLSTM, self).__init__()
        # 256 possible byte values -> dense embeddings.
        self.embedding = torch.nn.Embedding(256, embedding_width)
        self.lstm = torch.nn.LSTM(
            input_size=embedding_width,
            hidden_size=lstm_width,
            num_layers=lstm_depth,
            batch_first=True,
            bidirectional=True,
        )
        # Linear stack with ReLU between layers; the first layer consumes the
        # 2 * lstm_width concatenation of the two LSTM directions.
        mlp_layers = []
        for in_width, out_width in zip([2 * lstm_width] + mlp_widths, mlp_widths):
            if mlp_layers:
                mlp_layers.append(torch.nn.ReLU())
            mlp_layers.append(torch.nn.Linear(in_width, out_width))
        self.mlp = torch.nn.Sequential(*mlp_layers)

        self.lstm_width = lstm_width
        self.lstm_depth = lstm_depth
        # BUG FIX: this was a local `vocab = []` that was immediately
        # discarded, so get_classes() raised AttributeError on a fresh
        # instance until a caller assigned `model.vocab` externally.
        self.vocab = []

    def forward(self, byte_input):
        """Classify a batch of byte sequences.

        Args:
            byte_input: (batch, seq_len) integer tensor of byte values.

        Returns:
            (batch, num_classes) tensor of softmax probabilities.
        """
        token_emb = self.embedding(byte_input.long())
        # Explicit zero initial (h, c) states for the bidirectional LSTM,
        # matched to the embedding's dtype/device so the module also works
        # off-CPU (the original hard-coded CPU float32 zeros).
        empty = torch.zeros(
            self.lstm_depth * 2,
            token_emb.size(0),
            self.lstm_width,
            dtype=token_emb.dtype,
            device=token_emb.device,
        )
        rep, ns = self.lstm(token_emb, (empty, empty))
        # Mean-pool over the time dimension.
        pooled = torch.sum(rep, 1) / token_emb.shape[1]
        raw_scores = self.mlp(pooled)
        normalized = torch.nn.functional.softmax(raw_scores, dim=1)
        return normalized

    @torch.jit.export
    def get_classes(self):
        """Return the label list (populated from a checkpoint after construction)."""
        return self.vocab

    def run_on_text(self, text, limit=3, mod=None):
        """Score `text` and return the top-`limit` labels with probabilities.

        Args:
            text: input string; lowercased and UTF-8 encoded to bytes.
            limit: number of top-scoring classes to keep.
            mod: optional module to run instead of `self` (e.g. a scripted or
                re-loaded copy); `self` still supplies the label list.

        Returns:
            OrderedDict of label -> probability, sorted descending.
        """
        mod = mod if mod is not None else self
        text_bytes = text.lower().encode("utf-8")
        # np.frombuffer is the idiomatic spelling of the original
        # np.ndarray(..., buffer=...) construction (same int8 values).
        text_tensor = torch.as_tensor(
            np.frombuffer(text_bytes, dtype=np.byte).reshape(1, -1)
        )
        scores = mod(text_tensor)
        pairs = zip(self.get_classes(), scores[0].detach().numpy())
        return collections.OrderedDict(
            sorted(pairs, key=operator.itemgetter(1), reverse=True)[:limit]
        )
In [5]:
# Instantiate with the same hyper-parameters the checkpoint was trained with,
# and switch to inference mode.
model = ByteLSTM(
    embedding_width=64,
    lstm_width=128,
    lstm_depth=1,
    mlp_widths=[64, 16],
)
model.eval()
Out[5]:
ByteLSTM(
  (embedding): Embedding(256, 64)
  (lstm): LSTM(64, 128, batch_first=True, bidirectional=True)
  (mlp): Sequential(
    (0): Linear(in_features=256, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=16, bias=True)
  )
)
In [6]:
# Load the training checkpoint onto CPU; "model_state" holds parameters and
# "tensorizers" holds the label vocabulary (see the next cells).
# NOTE(review): newer torch versions default to weights_only=True, which would
# reject the pickled "tensorizers" object — pass weights_only=False there;
# confirm the torch version in use.
train_output = torch.load(f"{MODEL_ROOT}/model-{MODEL_ID}.pt", map_location="cpu")
WARNING:root:This caffe2 python run does not have GPU support. Will run in CPU only mode.
Install apex from https://github.com/NVIDIA/apex/.
In [7]:
# Copy each submodule's weights out of the flat checkpoint state-dict,
# then attach the label vocabulary. Trailing `None` suppresses cell output.
for module, prefix in [
    (model.embedding, "embedding.word_embedding."),
    (model.lstm, "representation.lstm.lstm."),
    (model.mlp[0], "decoder.mlp.0."),
    (model.mlp[2], "decoder.mlp.2."),
]:
    module.load_state_dict(grab_mod(train_output["model_state"], prefix))
model.vocab = list(train_output["tensorizers"]["labels"].vocab)
None
In [8]:
# Sanity-check the eager model on a sample query.
model.run_on_text("lebron")
Out[8]:
OrderedDict([('nba', 0.8307514),
             ('nfl', 0.091046356),
             ('fantasyfootball', 0.035300486)])
In [9]:
# Compile the model to TorchScript so it can be saved and served
# without this Python class definition.
smod = torch.jit.script(model)
smod.eval()
Out[9]:
ScriptModule(
  original_name=ByteLSTM
  (embedding): ScriptModule(original_name=Embedding)
  (lstm): ScriptModule(original_name=LSTM)
  (mlp): _ConstSequential(
    original_name=_ConstSequential
    (0): ScriptModule(original_name=Linear)
    (1): ScriptModule(original_name=ReLU)
    (2): ScriptModule(original_name=Linear)
  )
)
In [10]:
# The scripted module reproduces the eager model's scores (compare Out[8]).
smod.run_on_text("lebron")
Out[10]:
OrderedDict([('nba', 0.8307514),
             ('nfl', 0.091046356),
             ('fantasyfootball', 0.035300486)])
In [11]:
smod.save("model-reddit16-{MODEL_ID}.pt1")
In [12]:
loaded = torch.jit.load("model-reddit16-{MODEL_ID}.pt1")
loaded.eval()
Out[12]:
ScriptModule(
  original_name=ByteLSTM
  (embedding): ScriptModule(original_name=Embedding)
  (lstm): ScriptModule(original_name=LSTM)
  (mlp): ScriptModule(
    original_name=_ConstSequential
    (0): ScriptModule(original_name=Linear)
    (1): ScriptModule(original_name=ReLU)
    (2): ScriptModule(original_name=Linear)
  )
)
In [13]:
# The round-tripped (saved + re-loaded) module matches the original predictions.
model.run_on_text("lebron", mod=loaded)
Out[13]:
OrderedDict([('nba', 0.8307514),
             ('nfl', 0.091046356),
             ('fantasyfootball', 0.035300486)])
In [14]:
# A gaming-flavored query through the re-loaded module.
model.run_on_text("Vader is the worst in the game", mod=loaded)
Out[14]:
OrderedDict([('StarWarsBattlefront', 0.73975956),
             ('gaming', 0.22228362),
             ('Overwatch', 0.019848222)])
In [15]:
# A query with no explicit game name still routes to the right community.
model.run_on_text("If he hadn't been in the jungle, it would have been an easy win", mod=loaded)
Out[15]:
OrderedDict([('leagueoflegends', 0.9999074),
             ('SquaredCircle', 6.828745e-05),
             ('nba', 1.3255787e-05)])
In [16]:
# The label list survives the TorchScript round-trip via @torch.jit.export.
loaded.get_classes()
Out[16]:
['news',
 'nba',
 'CFB',
 'StarWarsBattlefront',
 'Overwatch',
 'fantasyfootball',
 'soccer',
 'todayilearned',
 'The_Donald',
 'leagueoflegends',
 'hockey',
 'nfl',
 'SquaredCircle',
 'politics',
 'worldnews',
 'gaming']