#include "TRInterface.h"
#include "TMVA/MethodC50.h"
#include "TMVA/MethodRSNNS.h"
#include "TMVA/MethodRXGB.h"

#include <stdexcept>

// TMVA classification example: book three R-interfaced methods (C50, RSNNS,
// RXGB) next to a native TMVA BDT on the "Sig"/"Bkg" trees of a remote ROOT
// file, then train, test and evaluate all of them through one Factory.

TMVA::Tools::Instance();

// The input is fetched over the network. TFile::Open yields nullptr (or a
// zombie file) on failure, so check before any dereference below.
auto inputFile = TFile::Open("https://raw.githubusercontent.com/iml-wg/tmvatutorials/master/inputdata.root");
if (!inputFile || inputFile->IsZombie())
   throw std::runtime_error("could not open TMVA tutorial input file");

// Results file written by the Factory.
// NOTE(review): the name says "CV", but no cross-validation is run here.
auto outputFile = TFile::Open("TMVAOutputCV.root", "RECREATE");
if (!outputFile || outputFile->IsZombie())
   throw std::runtime_error("could not create TMVA output file");

TMVA::Factory factory("TMVAClassification", outputFile,
                      "!V:ROC:!Correlations:!Silent:Color:!DrawProgressBar:AnalysisType=Classification");

TMVA::DataLoader loader("dataset");

// Input variables, looked up by branch name in both trees.
loader.AddVariable("var1");
loader.AddVariable("var2");
loader.AddVariable("var3");
loader.AddVariable("var4");

// GetObject leaves the pointer null when the key is absent or has another
// type, so initialize explicitly and verify before handing to the loader.
TTree *tsignal = nullptr;
TTree *tbackground = nullptr;
inputFile->GetObject("Sig", tsignal);
inputFile->GetObject("Bkg", tbackground);
if (!tsignal || !tbackground)
   throw std::runtime_error("input file lacks the 'Sig' and/or 'Bkg' tree");

// Default-constructed (empty) cuts: no preselection on either sample.
TCut mycuts;
TCut mycutb;

loader.AddSignalTree(tsignal, 1);         // signal weight = 1
loader.AddBackgroundTree(tbackground, 1); // background weight = 1
loader.PrepareTrainingAndTestTree(mycuts, mycutb,
                                  "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V");

// C5.0 classifier via the R interface (NTrials > 1 enables boosting; Rules=kTRUE
// requests a rule-based model).
factory.BookMethod(&loader, TMVA::Types::kC50, "C50",
                   "!H:NTrials=5:Rules=kTRUE:ControlSubSet=kFALSE:ControlBands=10:ControlWinnow=kFALSE:ControlNoGlobalPruning=kTRUE:ControlCF=0.25:ControlMinCases=2:ControlFuzzyThreshold=kTRUE:ControlSample=0:ControlEarlyStopping=kTRUE:!V");

// Multilayer perceptron via the R RSNNS package, trained with backpropagation.
factory.BookMethod(&loader, TMVA::Types::kRSNNS, "RMLP",
                   "!H:VarTransform=N:Size=c(5):Maxit=10:InitFunc=Randomize_Weights:LearnFunc=Std_Backpropagation:LearnFuncParams=c(0.2,0):!V");

// eXtreme Gradient Boosted (XGBoost) decision trees via the R interface.
factory.BookMethod(&loader, TMVA::Types::kRXGB, "RXGB", "!V:NRounds=20:MaxDepth=2:Eta=1");

// Native TMVA AdaBoost BDT, used as the reference method.
factory.BookMethod(&loader, TMVA::Types::kBDT, "BDT",
                   "!V:NTrees=50:MinNodeSize=2.5%:MaxDepth=2:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20");

// outputFile is intentionally left open here: a later cell still queries the
// factory (e.g. for ROC curves).
factory.TrainAllMethods();
factory.TestAllMethods();
factory.EvaluateAllMethods();
%jsroot on auto c = factory.GetROCCurve(&loader); c->Draw();