#!/usr/bin/env python
# coding: utf-8

# Word-level language model: an LSTM trained on the opening lines of
# "Alice's Adventures in Wonderland", then used to generate text.

# In[1]:

from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM
from keras.utils import np_utils
from keras.utils.data_utils import get_file
from keras.preprocessing import sequence
from keras.preprocessing.text import Tokenizer
import numpy as np

np.random.seed(13)


# In[2]:

# Download the book from Project Gutenberg, keep the first 50 lines,
# and turn each line into a sequence of word indices.
path = get_file('alice.txt', origin='http://www.gutenberg.org/cache/epub/11/pg11.txt')
doc = open(path).readlines()[0:50]
tokenizer = Tokenizer()
tokenizer.fit_on_texts(doc)
doc = tokenizer.texts_to_sequences(doc)
doc = [l for l in doc if len(l) > 1]                  # drop lines with fewer than two words
words_size = sum([len(words) - 1 for words in doc])   # total number of training targets


# In[3]:

maxlen = max([len(x) - 1 for x in doc])       # longest input prefix
vocab_size = len(tokenizer.word_index) + 1    # +1 for the padding index 0


# In[4]:

def generate_data(X, maxlen, V):
    """For each sentence, yield every prefix (padded to maxlen) together with
    a one-hot encoding of the word that follows it."""
    for sentence in X:
        inputs = []
        targets = []
        for i in range(1, len(sentence)):
            inputs.append(sentence[0:i])
            targets.append(sentence[i])
        y = np_utils.to_categorical(targets, V)
        inputs_sequence = sequence.pad_sequences(inputs, maxlen=maxlen)
        yield (inputs_sequence, y)


# In[5]:

def sample(p):
    """Draw a word index from the predicted probability distribution."""
    p = np.asarray(p).astype('float64')   # float64 so the probabilities sum to 1 for multinomial
    p /= sum(p)
    return np.where(np.random.multinomial(1, p, 1) == 1)[1][0]


# In[6]:

# Embedding -> LSTM -> softmax over the vocabulary.
model = Sequential()
model.add(Embedding(vocab_size, 128, input_length=maxlen))
model.add(LSTM(128, return_sequences=False))
model.add(Dense(vocab_size, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adadelta')


# In[7]:

# Train for 30 epochs; after each epoch, generate a sentence by sampling
# from the model's output distribution, starting from "alice's".
for i in range(30):
    for x, y in generate_data(doc, maxlen, vocab_size):
        model.train_on_batch(x, y)

    in_words = "alice's"
    for _ in range(maxlen):
        in_sequence = sequence.pad_sequences(tokenizer.texts_to_sequences([in_words]), maxlen=maxlen)
        wordid = sample(model.predict(in_sequence)[0])
        for k, v in tokenizer.word_index.items():
            if v == wordid:
                in_words += ' ' + k
                break
    print(i, in_words)


# In[8]:

# Greedy generation: always pick the most probable next word.
in_words = "alice's"
for _ in range(maxlen):
    in_sequence = sequence.pad_sequences(tokenizer.texts_to_sequences([in_words]), maxlen=maxlen)
    wordid = model.predict_classes(in_sequence, verbose=0)[0]
    for k, v in tokenizer.word_index.items():
        if v == wordid:
            in_words += ' ' + k
            break
print(in_words)


# In[ ]: