Neural network text prediction with textgenrnn

By Allison Parrish

This is a quick tutorial on how to use Max Woolf's textgenrnn to get started generating text with recurrent neural networks.

Like a Markov chain, a recurrent neural network (RNN) is a way to make predictions about what will come next in a sequence. For our purposes, the sequence in question is a sequence of characters, and the prediction we want to make is which character will come next. Both Markov models and recurrent neural networks do this by using statistical properties of text to make a probability distribution for what character will come next, given some information about what comes before. The two procedures work very differently internally, and we're not going to go into the gory details about implementation here. (But if you're interested in the gory details, here's a good place to start.) For our purposes, the main functional difference between a Markov chain and a recurrent neural network is the portion of the sequence used to make the prediction. A Markov model uses a fixed window of history from the sequence, while an RNN (theoretically) uses the entire history of the sequence.

Start with Markov

To illustrate, let's look at the word "condescendences." In a Markov model based on bigrams from this string of characters, you'd make a list of bigrams and the characters that follow those bigrams, like so:

n-grams  next?
co       n
on       d
nd       e, e
de       s, n
es       c, (end of text)
sc       e
ce       n, s
en       d, c
nc       e

You could also write this as a probability distribution, with one row for each bigram and one column for each character. The value in each cell indicates the probability that the bigram in a given row will be followed by the character in a given column:

n-grams  c    o    n    d    e    s    END
co       0    0    1.0  0    0    0    0
on       0    0    0    1.0  0    0    0
nd       0    0    0    0    1.0  0    0
de       0    0    0.5  0    0    0.5  0
es       0.5  0    0    0    0    0    0.5
sc       0    0    0    0    1.0  0    0
ce       0    0    0.5  0    0    0.5  0
en       0.5  0    0    0.5  0    0    0
nc       0    0    0    0    1.0  0    0

Each row of this table is a probability distribution, meaning that it shows how probable a given letter is to follow the n-gram in the original text. In a probability distribution, all of the values add up to 1.

Fitting a Markov model to the data is a matter of looking at each sequence of characters in a given text, and updating the table of probability distributions accordingly. To make a prediction from this table, you can "sample" from the probability distribution for a given n-gram (i.e., sampling from the distribution for the bigram de, you'd have a 50% chance of picking n and a 50% chance of picking s).
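To make this concrete, here's a quick Python sketch of the fitting process. (This is just an illustration of the idea, not code from any library; the variable names are made up for the example.)

```python
from collections import defaultdict

text = "condescendences"
END = "END"  # stand-in token for the end of the text

# Tally which characters follow each bigram in the text
followers = defaultdict(list)
for i in range(len(text) - 1):
    bigram = text[i:i+2]
    followers[bigram].append(text[i+2] if i + 2 < len(text) else END)

# Turn the tallies into probability distributions
model = {bigram: {ch: chars.count(ch) / len(chars) for ch in set(chars)}
         for bigram, chars in followers.items()}

print(model["de"])  # {'s': 0.5, 'n': 0.5} (dict order may vary)
```

Each entry in `model` corresponds to one row of the table above.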

Another way of thinking about this Markov model is as a (hypothetical!) function f that takes a bigram as a parameter and returns a probability distribution for that bigram:

f("ce") → [0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0]

(Note that the values at each index in this distribution line up with the columns in the table above.)

The items in the list returned from this function correspond to the probability for the corresponding next character, as given in the table. To sample from this list, you'd pick randomly among the indices according to their probabilities, and then look up the corresponding character by its position in the table.
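In Python, you could do this weighted pick with the random.choices function from the standard library, which accepts a list of weights. (Again, just a sketch to illustrate the idea.)

```python
import random

# Column order matches the table above
alphabet = ["c", "o", "n", "d", "e", "s", "END"]

# The distribution returned by f("ce"): "n" and "s" are equally likely
distribution = [0.0, 0.0, 0.5, 0.0, 0.0, 0.5, 0.0]

# Pick a character at random, weighted by the distribution
next_char = random.choices(alphabet, weights=distribution)[0]
print(next_char)  # "n" or "s", with equal probability
```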

To generate new text from this model:

  1. Set your output string to a randomly selected n-gram
  2. Sample a letter from the probability distribution associated with the n-gram at the end of the output string
  3. Append the sampled letter to the end of the string
  4. Repeat from (2) until the END token is reached
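Here's a toy Python implementation of these steps. (Sampling uniformly from the raw list of observed followers is equivalent to sampling from the probability distribution, so we can skip computing the probabilities explicitly; this is an illustration, not library code.)

```python
import random
from collections import defaultdict

text = "condescendences"
END = "END"

# Fit the bigram model: map each bigram to its list of observed followers
followers = defaultdict(list)
for i in range(len(text) - 1):
    followers[text[i:i+2]].append(text[i+2] if i + 2 < len(text) else END)

# 1. Set the output string to a randomly selected bigram
output = random.choice(list(followers))

while True:
    # 2. Sample a letter from the distribution for the final bigram
    next_char = random.choice(followers[output[-2:]])
    # 4. Stop when the END token is reached
    if next_char == END:
        break
    # 3. Append the sampled letter to the end of the string
    output += next_char

print(output)
```

Running this a few times produces strings like "condendences" or "escencondes": every bigram in the output appears somewhere in the original word, but the whole string needn't.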

Of course, you don't write this function by hand! When you're creating a Markov model from your data (or "training" the model), you're essentially asking the computer to write this function for you. In this sense, a Markov model is a very simple kind of machine learning, since the computer "learns" the probability distribution from the data that you feed it.

A (very) simplified explanation of RNNs

The mechanism by which a recurrent neural network "learns" probability distributions from sequences is much more sophisticated than the mechanism used in a Markov model, but functionally they're very similar: you give the computer some data to "train" on, and then ask it to automatically create a function that will return a probability distribution of what comes next, given some input. An RNN differs from a Markov chain in that to predict the next item in the sequence, you pass in the entire sequence instead of just the most recent n-gram.

In other words, you can (again, hypothetically) think of an RNN as a way of automatically creating a function f that takes a sequence of characters of arbitrary length and returns a probability distribution for which character comes next in the sequence. Unlike a Markov chain, it's possible to improve the accuracy of the probability distribution returned from this function by training on the same data multiple times.

Let's say that we want to train the RNN on the string "condescendences" to learn this function, and we want to make a prediction about what character is most likely to follow the sequence of characters "condescendence". When training a neural network, the process of learning a function like this works iteratively: you start off with a function that essentially gives a uniform probability distribution for each outcome (i.e., no one outcome is considered more likely than any other):

f("condescendence") → [0.14, 0.14, 0.14, 0.14, 0.14, 0.14, 0.14] (after zero passes through the data)

... and as you iterate over the training data (in this case, the word "condescendences"), the probability distribution gradually improves, ideally until it comes to accurately reflect the actual observed distribution (in the parlance, until it "converges"). After some number of passes through the data, you might expect the automatically-learned function to return distributions like this:

f("condescendence") → [0.01, 0.02, 0.01, 0.03, 0.01, 0.9, 0.02] (after n passes through the data)

A single pass through the training data is called an "epoch." For neural networks in general, and RNNs in particular, more epochs usually means a better fit to the training data, though past a certain point the model can overfit and start reproducing stretches of the training text verbatim.

To generate text from this model:

  1. Initialize your output string to an empty string, or a random character, or a starting "prefix" that you specify;
  2. Sample the next letter from the distribution returned for the current output string;
  3. Append that character to the end of the output string;
  4. Repeat from (2)

Of course, in a real-life application of either a Markov model or an RNN, you'd normally have more than seven items in the probability distribution! In fact, you'd have one element in the probability distribution for every unique character that occurs in the text. (Meaning that if there were 100 unique characters found in the text, the probability distribution would have 100 items in it.)
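You can verify this for our toy example by counting the distinct characters, plus one slot for the END token:

```python
text = "condescendences"

# One distribution slot per unique character, plus one for END
vocabulary = sorted(set(text)) + ["END"]
print(vocabulary)       # ['c', 'd', 'e', 'n', 'o', 's', 'END']
print(len(vocabulary))  # 7, matching the seven-item distributions above
```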

Markov chains vs RNNs

The primary benefit of an RNN over a Markov model for text generation is that an RNN takes into account the entire history of a sequence when generating the next character. This means that, for example, an RNN can theoretically learn how to close quotes and parentheses, which a Markov chain will never be able to reliably do (at least for pairs of quotes and parentheses longer than the n-gram of the Markov chain).

The drawback of RNNs is that they are computationally expensive, in terms of both processing and memory. This is (again) a simplification, but internally, RNNs work by "squishing" information about the training data down into large matrices, making predictions by performing calculations on those matrices. That means that you need a lot of CPU and RAM to train an RNN, and the resulting models (when stored to disk) can be very large. Training an RNN also (usually) takes a lot of time.

Another consideration is the size of your corpus. Markov models will give interesting and useful results even for very small datasets, but RNNs require large amounts of data to train—the more data the better.

So what do you do if you don't have a very large corpus? Or if you don't have a lot of time to train on your corpus?

RNN generation from pre-trained models

Fortunately for us, developer and data scientist Max Woolf has made a Python library called textgenrnn that makes it really easy to experiment with RNN text generation. This library includes a model (according to the documentation) "trained on hundreds of thousands of text documents, from Reddit submissions (via BigQuery) and Facebook Pages (via my Facebook Page Post Scraper), from a very diverse variety of subreddits/Pages," and allows you to use this model as a starting point for your own training.

To install textgenrnn, you'll probably want to install Keras first. With Anaconda:

In [ ]:
!conda install -y keras

Then install textgenrnn with pip:

In [ ]:
!pip install textgenrnn

Once it's installed, import the textgenrnn class from the package:

In [1]:
from textgenrnn import textgenrnn
Using TensorFlow backend.

And create a new textgenrnn object like so:

In [2]:
textgen = textgenrnn()

This object has a .generate() method which will, by default, generate text from the pre-trained model only.

In [3]:
textgen.generate()
Looking to the stream and a start in the first time to see the strong of the success and a president and the star to the problems in the street to the performance is a statement of the most life in 

To train it on your own text, use the .train_on_texts() method, passing in a list of strings. The num_epochs parameter allows you to indicate how many epochs (i.e., passes over the data) should be performed. The more epochs the better, especially for shorter texts, but you'll get okay results even with just a few. For The Road Not Taken, twenty epochs worked well for me:

In [10]:
textgen.train_on_texts(open("frost.txt").readlines(), num_epochs=20)
Epoch 1/20
754/754 [==============================] - 1s 2ms/step - loss: 0.9443
Epoch 2/20
754/754 [==============================] - 1s 2ms/step - loss: 0.8913
Epoch 3/20
754/754 [==============================] - 1s 2ms/step - loss: 0.8354
Epoch 4/20
754/754 [==============================] - 1s 2ms/step - loss: 0.7948
Epoch 5/20
754/754 [==============================] - 1s 2ms/step - loss: 0.7595
Epoch 6/20
754/754 [==============================] - 2s 2ms/step - loss: 0.7419
Epoch 7/20
754/754 [==============================] - 2s 2ms/step - loss: 0.7108
Epoch 8/20
754/754 [==============================] - 2s 2ms/step - loss: 0.6971
Epoch 9/20
754/754 [==============================] - 2s 2ms/step - loss: 0.6671
Epoch 10/20
754/754 [==============================] - 2s 2ms/step - loss: 0.6484
Epoch 11/20
754/754 [==============================] - 2s 2ms/step - loss: 0.6317
Epoch 12/20
754/754 [==============================] - 2s 2ms/step - loss: 0.6174
Epoch 13/20
754/754 [==============================] - 2s 2ms/step - loss: 0.6048
Epoch 14/20
754/754 [==============================] - 2s 3ms/step - loss: 0.5942
Epoch 15/20
754/754 [==============================] - 2s 2ms/step - loss: 0.5849
Epoch 16/20
754/754 [==============================] - 2s 2ms/step - loss: 0.5769
Epoch 17/20
754/754 [==============================] - 2s 2ms/step - loss: 0.5711
Epoch 18/20
754/754 [==============================] - 2s 2ms/step - loss: 0.5641
Epoch 19/20
754/754 [==============================] - 2s 2ms/step - loss: 0.5597
Epoch 20/20
754/754 [==============================] - 2s 2ms/step - loss: 0.5564

After training, you can generate new text using the .generate() method again:

In [11]:
textgen.generate()
And both and as down that the first ages and be that the first that the first travelled one as far as I could ages and be about the wood,


The results aren't very interesting because by default the generator is very conservative in how it samples from the probability distribution. You can use the temperature parameter to make the sampling a bit more likely to pick improbable outcomes. The higher the value, the weirder the results. The default is 0.2, and going above 1.0 is likely to produce unacceptably strange results:

In [13]:
textgen.generate(temperature=0.5)
In leaves sorry I could sood sorry I having the poor as I had to down the ages in the wive as the diverged in the one traveler, and I stood


In [14]:
textgen.generate(temperature=0.9)
Tow that the one one at morning there in a yell travels to what to worre.
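If you're curious what temperature is doing under the hood, the standard trick is to scale the log-probabilities and renormalize; higher temperatures flatten the distribution, lower temperatures sharpen it. (This is a sketch of the general technique, not necessarily textgenrnn's exact internals.)

```python
import math

def reweight(distribution, temperature):
    # Scale the log-probabilities by 1/temperature, then renormalize
    logs = [math.log(p) / temperature for p in distribution]
    exps = [math.exp(l) for l in logs]
    total = sum(exps)
    return [e / total for e in exps]

dist = [0.1, 0.2, 0.7]
print(reweight(dist, 0.2))  # low temperature: the likeliest outcome dominates
print(reweight(dist, 1.0))  # temperature 1.0 leaves the distribution unchanged
print(reweight(dist, 2.0))  # high temperature: the probabilities flatten out
```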

If you pass a number n to the .generate() method as its first parameter, and the parameter return_as_list=True, .generate() will return a list of n instances of text generation from the model:

In [15]:
poem = textgen.generate(10, temperature=0.8, return_as_list=True)
for line in poem:
    print(line.strip())
And as way that other to screase.
Because other good on the feel about the first doad that stack,
And having the undergrowthy and wanted sororing about the roads one way to one eneust the first and equally bectro equally leaves in as another ages and as fair,
And ages in the brothes as mance as for the roads in the one least way tallmorth in for a black.
I shall be some way leads both diverged for, but worn in the one less travelled there
And be in the condence, and I stood
To kent the passing there

In and be ages travel both the grasside, and I haved be soober,
Then diverged in an iqually least less llastread diverged in a I had a longe as fair,

(This may take a little while.)

I've found that textgenrnn works especially well with very short, word-length texts. For example, download this file of human moods from the Corpora Project, and put it in the same directory as this notebook. The textgen object retains the weights from the training we just did, so you'll first need to reset it to its initial state:

In [23]:
textgen.reset()

Then load the JSON file and grab just the list of words naming moods:

In [24]:
import json
with open("./moods.json") as fh:
    mood_data = json.load(fh)
moods = mood_data['moods']

Now, train the RNN on these moods. One epoch will do the trick:

In [25]:
textgen.train_on_texts(moods, num_epochs=1)
Epoch 1/1
6651/6651 [==============================] - 13s 2ms/step - loss: 2.1288

Now generate a list of new moods:

In [28]:
new_moods = textgen.generate(25, temperature=0.8, return_as_list=True)

And print them out!

In [30]:
print('\n'.join(new_moods))
terered
distructive
formersaked form
usedious
worship
lose
day
experientacled
sircial
influet
indeelled
buld
phoneed
ounded
guy
announced
orchosies
time
boardy
reless
includive
sometifent
upcomist
parablever
suffered

I don't know about you, but I'm feeling a little sometifent today.

Further reading