I am planning to use the DL4J Doc2Vec implementation for sentiment analysis.
However, I don't want to start with an empty network: the starting point should be a pre-trained network. The initial training is done with the Sentiment140 dataset, which can be found at https://www.kaggle.com/kazanova/sentiment140. It contains 1,600,000 tweets extracted using the Twitter API.
In this Workbook I describe how to train and save a DL4J Doc2Vec.
We install the following libraries with the help of Maven:
%classpath config resolver maven-public http://192.168.1.10:8081/repository/maven-public/
%%classpath add mvn
org.deeplearning4j:deeplearning4j-nlp:1.0.0-beta2
org.deeplearning4j:deeplearning4j-core:1.0.0-beta2
org.nd4j:nd4j-native-platform:1.0.0-beta2
com.github.habernal:confusion-matrix:1.0
Added new repo: maven-public
and we import all relevant classes
import org.deeplearning4j.text.documentiterator.LabelledDocument
import org.deeplearning4j.text.documentiterator.SimpleLabelAwareIterator
import org.deeplearning4j.text.tokenization.tokenizerfactory._
import org.deeplearning4j.text.tokenization.tokenizer.TokenPreProcess
import org.deeplearning4j.text.tokenization.tokenizer.preprocessor._
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer
import org.deeplearning4j.models.paragraphvectors.ParagraphVectors
import java.io._
import scala.io.Source
import scala.util.Random
import com.github.habernal.confusionmatrix.ConfusionMatrix
We will use LabelledDocument objects to feed our network.
The sentiment field in the Sentiment140 file uses the following values: 0 = negative, 2 = neutral, 4 = positive.
We create a new LabelledTwitterDocument class which extends LabelledDocument, with the goal of being able to create new LabelledDocument objects with a simple constructor. We also want to use readable text labels, so we add some label translation logic:
class LabelledTwitterDocument(txt: String, lbl: String) extends LabelledDocument {
  this.setContent(txt)
  setupLabel(lbl)

  // 0=negative 2=neutral 4=positive
  def setupLabel(number: String) = {
    var v = number.trim()
    if (v == "\"0\"") v = "negative"
    else if (v == "\"2\"") v = "neutral"
    else if (v == "\"4\"") v = "positive"
    this.addLabel(v)
  }
}
defined class LabelledTwitterDocument
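The label translation can also be sketched in plain Scala, independent of DL4J. Note that the raw CSV wraps every field in double quotes, which is why the comparisons include escaped quotes:

```scala
// Standalone sketch of the label translation used in LabelledTwitterDocument.
// The raw Sentiment140 CSV quotes each field, hence the escaped quotes.
def translateLabel(number: String): String = number.trim() match {
  case "\"0\"" => "negative"
  case "\"2\"" => "neutral"
  case "\"4\"" => "positive"
  case other   => other // fall back to the raw value for unexpected codes
}
```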
Usually I load the data directly from the internet. In this case, however, the amount of data is too big and it is much more efficient to work with a local file.
var labeledDocumentList = Source.fromFile("training.1600000.processed.noemoticon.csv", "ISO-8859-1")
  .getLines()
  .map(str => str.split(",", 6)) // limit to 6 fields so commas inside the tweet text survive
  .map(sa => new LabelledTwitterDocument(sa(5), sa(0)))
  .toList
println(labeledDocumentList.last)
labeledDocumentList.size
LabelledDocument(id=null, content="happy #charitytuesday @theNSPCC @SparksCharity @SpeakingUpH4H ", labels=[positive])
1600000
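One caveat with splitting the CSV rows on commas: the tweet text is the sixth field and may itself contain commas, so a plain split(",") truncates it. Passing a split limit keeps such commas intact. A small sketch of the difference (plain Scala, no DL4J, using a made-up example row):

```scala
// Sketch: why a split limit matters for tweets that contain commas.
val row = "\"0\",\"1\",\"date\",\"NO_QUERY\",\"user\",\"hello, world\""
val naive   = row.split(",")    // 7 pieces: the tweet text is cut at its comma
val limited = row.split(",", 6) // at most 6 fields: commas in the text survive
// naive(5)   == "\"hello"
// limited(5) == "\"hello, world\""
```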
import scala.collection.JavaConversions._
labeledDocumentList = Random.shuffle(labeledDocumentList)
val split = (0.95 * labeledDocumentList.length).toInt
val trainingData = labeledDocumentList.slice(0, split)
val testData = labeledDocumentList.slice(split, labeledDocumentList.length)
testData.length
80000
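The 95/5 split arithmetic can be checked in isolation: with 1,600,000 documents the split index is 1,520,000, which leaves exactly the 80,000 test documents reported above.

```scala
// Sketch of the train/test split sizes for 1,600,000 documents.
val total = 1600000
val split = (0.95 * total).toInt // 1520000
val testSize = total - split     // 80000
```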
We load the stop words. Here we can use Source.fromFile without issues. DL4J needs Java collections, so we make sure to convert the result into a Java list.
import scala.collection.JavaConversions._
val stopWords:java.util.List[String] = Source.fromFile("stopwords.txt").getLines.toList
[a, about, above, after, again, against, all, am, an, and, any, are, aren't, as, at, be, because, been, before, being, below, between, both, but, by, can't, cannot, could, couldn't, did, didn't, do, does, doesn't, doing, don't, down, during, each, few, for, from, further, had, hadn't, has, hasn't, have, haven't, having, he, he'd, he'll, he's, her, here, here's, hers, herself, him, himself, his, how, how's, i, i'd, i'll, i'm, i've, if, in, into, is, isn't, it, it's, its, itself, let's, me, more, most, mustn't, my, myself, no, nor, not, of, off, on, once, only, or, other, ought, our, ours ourselves, out, over, own, same, shan't, she, she'd, she'll, she's, should, shouldn't, so, some, such, than, that, that's, the, their, theirs, them, themselves, then, there, there's, these, they, they'd, they'll, they're, they've, this, those, through, to, too, under, until, up, very, was, wasn't, we, we'd, we'll, we're, we've, were, weren't, what, what's, when, when's, where, where's, which, while, who, who's, whom, why, why's, with, won't, would, wouldn't, you, you'd, you'll, you're, you've, your, yours, yourself, yourselves]
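To illustrate what the stop-word list does during training, here is a plain-Scala sketch (not DL4J's actual code path): tokens that appear in the list are simply dropped.

```scala
// Sketch: dropping stop words from a token stream (illustrative only).
val stop = Set("i", "am", "about", "the")
val tokens = List("i", "am", "happy", "about", "the", "weather")
val kept = tokens.filterNot(stop) // a Set can be used directly as a predicate
// kept == List("happy", "weather")
```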
We use the default tokenizer (an n-gram tokenizer factory is kept below as a commented-out alternative). We also want to ignore all words which start with @, # or http, and remove the endings ed, ing, ly and s, so we do this with a custom preprocessor.
val tokenizerFactory = new DefaultTokenizerFactory()
//val ngramTokenizerFactory = new NGramTokenizerFactory(tokenizerFactory, 1, 3)

class TwitterPreprocessor extends TokenPreProcess {
  val pp = new CommonPreprocessor()
  val ep = new EndingPreProcessor()

  // Drop mentions, hashtags and URLs; otherwise apply the standard
  // cleanup followed by ending removal (ed, ing, ly, s).
  def preProcess(token: String): String = {
    if (token.startsWith("@") || token.startsWith("#") || token.startsWith("http")) ""
    else ep.preProcess(pp.preProcess(token))
  }
}

//ngramTokenizerFactory.setTokenPreProcessor(new TwitterPreprocessor())
tokenizerFactory.setTokenPreProcessor(new TwitterPreprocessor())
""
val pp = new TwitterPreprocessor()
println("- "+pp.preProcess("valid"))
println("- "+pp.preProcess("https://test.com"))
println("- "+pp.preProcess("houses"))
println("- "+pp.preProcess("heavenly"))
println("- "+pp.preProcess("#test"))
println("- "+pp.preProcess("@test"))
""
- valid
-
- house
- heaven
-
-
Now we have everything that we need for the training of our network. We start the training! Please note that we defined the epochs in our ParagraphVectors builder, so we do not need to call the fit() method in a loop.
var iterator = new SimpleLabelAwareIterator(trainingData)
val paragraphVectors = new ParagraphVectors.Builder()
.learningRate(0.025)
.minLearningRate(0.001)
.batchSize(1000)
.minWordFrequency(10)
.stopWords(stopWords)
.epochs(30)
.iterate(iterator)
.trainWordVectors(true)
.tokenizerFactory(tokenizerFactory)
.build()
// Start model training
paragraphVectors.fit()
null
We save the model with the help of the WordVectorSerializer class:
WordVectorSerializer.writeWord2VecModel(paragraphVectors,"sentiment140.model")
After the training we want to validate if our Doc2Vec does something useful:
paragraphVectors.predict(new LabelledTwitterDocument("Microsoft looses customers",""))
positive
paragraphVectors.predict(new LabelledTwitterDocument("Microsoft wins contract",""))
positive
var errors = 0
var confusionMatrix = new ConfusionMatrix()
var iteratorTest = new SimpleLabelAwareIterator(testData)

while (iteratorTest.hasNext()) {
  val doc = iteratorTest.next()
  try {
    val predictedLabel = paragraphVectors.predict(doc)
    if (!doc.getContent().isEmpty) {
      confusionMatrix.increaseValue(doc.getLabel, predictedLabel, 1)
    }
  } catch {
    case e: Exception => errors += 1
  }
}
println("---------------")
println(s"Errors: $errors")
println(s"Precision: ${confusionMatrix.getAvgPrecision()}")
println(s"Recall: ${confusionMatrix.getAvgRecall()}")
println(confusionMatrix);
confusionMatrix.printNiceResults()
---------------
Errors: 1001
Precision: 0.7336231169089479
Recall: 0.733607324851242
 ↓gold\pred→  negative  positive
 negative        29069     10240
 positive        10807     28883
Macro F-measure: 0.734, (CI at .95: 0.003), micro F-measure (acc): 0.734
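The averaged precision and recall reported above can be recomputed by hand from the confusion-matrix counts (a macro average over the two classes), which is a useful sanity check on the numbers:

```scala
// Recompute macro-averaged precision/recall from the confusion matrix above.
// Rows = gold label, columns = predicted label: (negative, positive).
val tpNeg = 29069.0; val negAsPos = 10240.0
val posAsNeg = 10807.0; val tpPos = 28883.0
val precNeg = tpNeg / (tpNeg + posAsNeg)   // predicted-negative column
val precPos = tpPos / (tpPos + negAsPos)   // predicted-positive column
val recNeg  = tpNeg / (tpNeg + negAsPos)   // gold-negative row
val recPos  = tpPos / (tpPos + posAsNeg)   // gold-positive row
val avgPrecision = (precNeg + precPos) / 2 // ≈ 0.7336
val avgRecall    = (recNeg + recPos) / 2   // ≈ 0.7336
```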