DNN Example

Declare Factory

In [ ]:
from ROOT import TMVA, TFile, TTree, TCut, TString
In [ ]:
# Make sure the TMVA tools instance exists before any other TMVA call.
TMVA.Tools.Instance()

# Input data are read from a remote ROOT file.
inputFile = TFile.Open("https://raw.githubusercontent.com/iml-wg/tmvatutorials/master/inputdata.root")
# TFile.Open reports failure through a null or "zombie" file object instead of
# raising, so fail here with a clear message rather than later when the trees
# are requested.
if not inputFile or inputFile.IsZombie():
    raise RuntimeError("Could not open the input file 'inputdata.root' (is remote/HTTPS file access available?)")

# Output file handed to the factory below; RECREATE overwrites an existing file.
outputFile = TFile.Open("TMVAOutputDNN.root", "RECREATE")
if not outputFile or outputFile.IsZombie():
    raise RuntimeError("Could not create the output file 'TMVAOutputDNN.root'")

# Classification factory named "TMVAClassification", bound to the output file.
factory = TMVA.Factory("TMVAClassification", outputFile,
        "!V:!Silent:Color:!DrawProgressBar:AnalysisType=Classification" )

Declare Variables in DataLoader

In [ ]:
# The DataLoader named "dataset_dnn" collects the input variables used below.
loader = TMVA.DataLoader("dataset_dnn")

# Four plain variables followed by two derived ones ("name := expression").
# The registration order is kept exactly as listed.
for _expression in (
    "var1",
    "var2",
    "var3",
    "var4",
    "var5 := var1-var3",
    "var6 := var1+var2",
):
    loader.AddVariable(_expression)

Setup Dataset(s)

In [ ]:
# Fetch the signal ("Sig") and background ("Bkg") trees from the input file.
tsignal = inputFile.Get("Sig")
tbackground = inputFile.Get("Bkg")

# Get() is expected to hand back a null object when the key does not exist;
# stop with an explicit error instead of passing a null tree on to the loader.
for _key, _tree in (("Sig", tsignal), ("Bkg", tbackground)):
    if not _tree:
        raise RuntimeError("Tree '%s' was not found in the input file" % _key)

loader.AddSignalTree(tsignal)
loader.AddBackgroundTree(tbackground)

# Empty cut (no preselection); 1000 signal and 1000 background events are
# requested for training, with random splitting.
loader.PrepareTrainingAndTestTree(TCut(""),
        "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V")

Configure Network Layout

In [ ]:
# Network layout: three TANH layers of 128 nodes and a LINEAR output layer.
layoutString = TString("Layout=TANH|128,TANH|128,TANH|128,LINEAR")

# Two training phases. They differ only in the learning rate (1e-1 vs 1e-2)
# and in the DropConfig values.
training0 = TString(
    "LearningRate=1e-1,Momentum=0.9,Repetitions=1,"
    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
    "WeightDecay=1e-4,Regularization=L2,"
    "DropConfig=0.0+0.5+0.5+0.5, Multithreading=True"
)
training1 = TString(
    "LearningRate=1e-2,Momentum=0.9,Repetitions=1,"
    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
    "WeightDecay=1e-4,Regularization=L2,"
    "DropConfig=0.0+0.0+0.0+0.0, Multithreading=True"
)

# The phases are joined with "|" behind the "TrainingStrategy=" key.
trainingStrategyString = TString("TrainingStrategy=")
for _phase in (training0, TString("|"), training1):
    trainingStrategyString.Append(_phase)

# Full option string: general options, then layout, then training strategy,
# each separated by ":".
dnnOptions = TString(
    "!H:!V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
    "WeightInitialization=XAVIERUNIFORM"
)
for _piece in (":", layoutString, ":", trainingStrategyString):
    dnnOptions.Append(_piece)

Booking Methods

In [ ]:
# Book the DNN under the name "DNN" using the CPU architecture.
# The variable keeps its historical name `stdOptions` so that any code using
# it keeps working, but the architecture requested is CPU, not the
# dependency-free standard implementation.
# NOTE(review): presumably the option was switched to CPU because recent ROOT
# releases no longer accept Architecture=STANDARD -- confirm against the
# installed ROOT version.
stdOptions = dnnOptions + ":Architecture=CPU"
factory.BookMethod(loader, TMVA.Types.kDNN, "DNN", stdOptions)

Train Methods

In [ ]:
# Train every method booked on the factory (here: the single "DNN" method).
factory.TrainAllMethods()

Test and Evaluate Methods

In [ ]:
# Run the test step first, then the evaluation step, for all booked methods.
# The order matters: evaluation is performed after testing.
factory.TestAllMethods()
factory.EvaluateAllMethods()

Plot ROC Curve

We enable JavaScript (JSROOT) visualisation for the plots below.

In [ ]:
# Turn on JavaScript-based rendering for the ROOT plots drawn below.
%jsroot on
In [ ]:
# Retrieve the ROC-curve canvas for the methods booked on this loader and draw it.
c = factory.GetROCCurve(loader)
c.Draw()

# Nothing else in the notebook closes the output file, so close it here to
# make sure everything written to it is flushed to disk.
outputFile.Close()