#!/usr/bin/env python
# coding: utf-8
#
#
#
# # DNN Example
# ## Declare Factory
# In[ ]:
from ROOT import TMVA, TFile, TTree, TCut, TString
# In[ ]:
# Make sure TMVA's Tools instance exists before the Factory is created.
TMVA.Tools.Instance()
# Input sample for the tutorial; it is fetched over HTTPS, so this needs
# network access.
inputFile = TFile.Open("https://raw.githubusercontent.com/iml-wg/tmvatutorials/master/inputdata.root")
# TFile.Open hands back a null object when the file cannot be opened; stop
# here with a clear message instead of failing later on a null file.
if not inputFile or inputFile.IsZombie():
    raise RuntimeError("Could not open the TMVA tutorial input file")
# Output file that the Factory writes its training/testing results into.
outputFile = TFile.Open("TMVAOutputDNN.root", "RECREATE")
if not outputFile or outputFile.IsZombie():
    raise RuntimeError("Could not create TMVAOutputDNN.root")
factory = TMVA.Factory("TMVAClassification", outputFile,
                       "!V:!Silent:Color:!DrawProgressBar:AnalysisType=Classification")
# ## Declare Variables in DataLoader
# In[ ]:
# DataLoader named "dataset_dnn" that collects the input variables and trees.
loader = TMVA.DataLoader("dataset_dnn")
# Register the input variables in order. Entries of the form
# "name := expression" define a new variable computed from the other inputs.
for variableExpression in (
    "var1",
    "var2",
    "var3",
    "var4",
    "var5 := var1-var3",
    "var6 := var1+var2",
):
    loader.AddVariable(variableExpression)
# ## Setup Dataset(s)
# In[ ]:
# Signal and background trees stored in the input file under "Sig" and "Bkg".
tsignal = inputFile.Get("Sig")
tbackground = inputFile.Get("Bkg")
# Get() yields a null/None object when the key is missing; fail clearly
# instead of passing a null tree into TMVA.
if not tsignal or not tbackground:
    raise RuntimeError("Input file is missing the 'Sig' and/or 'Bkg' tree")
loader.AddSignalTree(tsignal)
loader.AddBackgroundTree(tbackground)
# Empty TCut: no preselection. Train on 1000 signal and 1000 background
# events, chosen with a random split; remaining options are TMVA defaults.
loader.PrepareTrainingAndTestTree(TCut(""),
                                  "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V")
# # Configure Network Layout
# In[ ]:
# General layout: three TANH hidden layers of 128 nodes each, followed by a
# LINEAR output layer.
layoutString = TString("Layout=TANH|128,TANH|128,TANH|128,LINEAR")
# Training strategies: two phases, joined with "|" below, run in sequence.
# Phase 0 uses the larger learning rate and a DropConfig of 0.5 on every layer
# after the first; phase 1 lowers the learning rate and sets all DropConfig
# entries to 0. Both phases use L2 regularisation with WeightDecay=1e-4.
# Options within a phase are comma-separated; keep them free of stray
# whitespace, since a space after a comma could end up as part of the next
# option's key when TMVA splits the string.
training0 = TString("LearningRate=1e-1,Momentum=0.9,Repetitions=1,"
                    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
                    "WeightDecay=1e-4,Regularization=L2,"
                    "DropConfig=0.0+0.5+0.5+0.5,Multithreading=True")
training1 = TString("LearningRate=1e-2,Momentum=0.9,Repetitions=1,"
                    "ConvergenceSteps=2,BatchSize=256,TestRepetitions=10,"
                    "WeightDecay=1e-4,Regularization=L2,"
                    "DropConfig=0.0+0.0+0.0+0.0,Multithreading=True")
trainingStrategyString = TString("TrainingStrategy=")
trainingStrategyString += training0 + TString("|") + training1
# General options: cross-entropy error strategy, VarTransform=N on the inputs
# and Xavier-uniform weight initialisation. The layout and training strategy
# are appended as further ':'-separated options.
dnnOptions = TString("!H:!V:ErrorStrategy=CROSSENTROPY:VarTransform=N:"
                     "WeightInitialization=XAVIERUNIFORM")
dnnOptions.Append(":")
dnnOptions.Append(layoutString)
dnnOptions.Append(":")
dnnOptions.Append(trainingStrategyString)
# # Booking Methods
# In[ ]:
# Book the DNN on the CPU backend (Architecture=CPU) under the name "DNN".
# NOTE(review): this option string selects the same backend as the
# commented-out "DNN CPU" booking below; it is not a separate
# dependency-free implementation. Confirm which backend was intended.
stdOptions = dnnOptions + ":Architecture=CPU"
factory.BookMethod(loader, TMVA.Types.kDNN, "DNN", stdOptions)
# Alternative booking kept for reference. Its options are identical to
# stdOptions, so enabling it would train a second, identically configured
# DNN under the name "DNN CPU".
#cpuOptions = dnnOptions + ":Architecture=CPU"
#factory.BookMethod(loader, TMVA.Types.kDNN, "DNN CPU", cpuOptions)
# ## Train Methods
# In[ ]:
# Train every method booked on the factory.
factory.TrainAllMethods()
# ## Test and Evaluate Methods
# In[ ]:
# Run TMVA's test stage and then its evaluation stage for all booked methods;
# testing must come first.
factory.TestAllMethods()
factory.EvaluateAllMethods()
# ## Plot ROC Curve
# We enable JavaScript visualisation for the plots
# In[ ]:
# Enable the %jsroot magic (presumably supplied by ROOT's notebook
# integration) so plots render with JavaScript. get_ipython only exists inside
# an IPython/Jupyter session; when this file is run as a plain script (see the
# shebang), skip the magic instead of crashing with NameError.
try:
    ipythonShell = get_ipython()  # noqa: F821 -- injected by IPython
except NameError:
    ipythonShell = None
if ipythonShell is not None:
    ipythonShell.run_line_magic('jsroot', 'on')
# In[ ]:
# Fetch the ROC-curve plot for the methods trained on `loader`'s dataset and
# draw it. Binding it to `c` keeps a Python reference to it (presumably so it
# is not garbage-collected before it is rendered -- confirm).
c = factory.GetROCCurve(loader)
c.Draw()