TMVA Minimal Classification

Minimal self-contained example for setting up TMVA with binary classification.

This is intended as a simple foundation to build on. It assumes you are already familiar with TMVA, so concepts such as the Factory and the DataLoader are not explained. For descriptions and tutorials, see the TMVA User's Guide (https://root.cern.ch/root-user-guides-and-manuals under TMVA) or the more detailed examples provided with TMVA, e.g. TMVAClassification.C.

Sets up a minimal binary classification example with two partially overlapping 2-D distributions (uniform in x and y, as generated by the genTree helper) and trains a BDT classifier to discriminate between them.

  • Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
  • Package : TMVA
  • Root Macro: TMVAMinimalClassification.C

Author: Kim Albertsson
This notebook tutorial was automatically generated with ROOTBOOK-izer from the macro found in the ROOT repository on Saturday, September 18, 2021 at 09:47 AM.

In [1]:
%%cpp -d
#include "TMVA/DataLoader.h"
#include "TMVA/Factory.h"

#include "TCut.h"
#include "TFile.h"
#include "TRandom.h"
#include "TString.h"
#include "TTree.h"

Minimal setup for performing binary classification in TMVA.

Modify the setup to your liking and run with root -l -b -q TMVAMinimalClassification.C. This will generate an output file "out.root" that can be viewed with root -l -e 'TMVA::TMVAGui("out.root")'.

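Helper function that generates 2-D data points and fills them into a ROOT TTree. As written, both coordinates are drawn from TRandom::Rndm(), so the points are uniformly distributed: x lies between 0 and scale, y between offset and offset + scale.

Arguments:
  • nPoints : Number of points to generate.
  • offset : Lower edge of the generated y values.
  • scale : Width of the generated range in x and y.
  • seed : Seed for the random number generator. Use seed=0 for a random seed.

Returns a TTree ready to be used as input to TMVA.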

In [2]:
%%cpp -d
TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
{
   TRandom rng(seed);
   Double_t x = 0;
   Double_t y = 0;

   TTree *data = new TTree();
   data->Branch("x", &x, "x/D");
   data->Branch("y", &y, "y/D");

   for (Int_t n = 0; n < nPoints; ++n) {
      x = rng.Rndm() * scale;
      y = offset + rng.Rndm() * scale;
      data->Fill();
   }

   // Important: Disconnects the tree from the memory locations of x and y.
   data->ResetBranchAddresses();
   return data;
}
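
Since genTree computes each coordinate as a uniform random number times scale (plus offset for y), the expected mean and standard deviation of a coordinate follow directly from the formulas for a uniform distribution. The sketch below uses only the C++ standard library; the names UniformMoments and uniformMoments are hypothetical and not part of ROOT or TMVA.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper type: analytic moments of one generated coordinate.
struct UniformMoments {
   double mean;
   double stddev;
};

// A coordinate computed as `offset + u * scale`, with u uniform on the unit
// interval, is uniform on [offset, offset + scale].
UniformMoments uniformMoments(double offset, double scale)
{
   assert(scale >= 0.0);
   return {offset + 0.5 * scale, scale / std::sqrt(12.0)};
}
```

With scale = 2.0 this gives a standard deviation of about 0.577 and a mean of 1.0 for offset = 0.0 (or 2.0 for offset = 1.0), which is consistent with the statistics for x printed in the training output further down (mean 1.0229, RMS 0.57835).
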
In [3]:
TString outputFilename = "out.root";
TFile *outFile = new TFile(outputFilename, "RECREATE");

Data generation

In [4]:
TTree *signalTree = genTree(1000, 0.0, 2.0, 100);
TTree *backgroundTree = genTree(1000, 1.0, 2.0, 101);

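TString factoryOptions = "AnalysisType=Classification";
// The first argument is the job name. It is left empty here, which is why the
// weight files written during training are named "_BDT.*".
TMVA::Factory factory{"", outFile, factoryOptions};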

TMVA::DataLoader dataloader{"dataset"};
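
With the arguments passed to genTree above, and given that genTree draws both coordinates from TRandom::Rndm(), x is uniform on [0, 2] for both classes and carries no separating information, while y is uniform on [0, 2] for signal and on [1, 3] for background. Under those assumptions the best achievable ROC curve can be worked out by hand as a reference point. Half of the signal lies at y < 1, where there is no background, and in the overlap 1 ≤ y ≤ 2 the two densities are equal, so the curve rises with unit slope there. The symbols \varepsilon_S (signal efficiency) and \varepsilon_B (background efficiency) are introduced here for illustration.

```latex
\varepsilon_S(\varepsilon_B) =
\begin{cases}
  \tfrac{1}{2} + \varepsilon_B, & 0 \le \varepsilon_B \le \tfrac{1}{2}, \\
  1, & \tfrac{1}{2} < \varepsilon_B \le 1,
\end{cases}
\qquad
\mathrm{AUC} = \int_0^1 \varepsilon_S \, d\varepsilon_B
= \int_0^{1/2} \left( \tfrac{1}{2} + \varepsilon_B \right) d\varepsilon_B + \tfrac{1}{2}
= \tfrac{3}{8} + \tfrac{1}{2} = 0.875 .
```

The evaluation output later in this notebook reports a ROC integral of 0.870 and test-sample signal efficiencies of 0.495, 0.622 and 0.794 at background efficiencies of 0.01, 0.10 and 0.30, close to the 0.51, 0.60 and 0.80 given by the expression above.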

Data specification

In [5]:
dataloader.AddVariable("x", 'D');
dataloader.AddVariable("y", 'D');

dataloader.AddSignalTree(signalTree, 1.0);
dataloader.AddBackgroundTree(backgroundTree, 1.0);

TCut signalCut = "";
TCut backgroundCut = "";
TString datasetOptions = "SplitMode=Random";
dataloader.PrepareTrainingAndTestTree(signalCut, backgroundCut, datasetOptions);
<HEADER> DataSetInfo              : [dataset] : Added class "Signal"
                         : Add Tree  of type Signal with 1000 events
<HEADER> DataSetInfo              : [dataset] : Added class "Background"
                         : Add Tree  of type Background with 1000 events
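
Each class enters with 1000 events, and with SplitMode=Random and no explicit event counts TMVA goes on to report 500 training and 500 testing events per class. As a rough illustration of what a random half-and-half split does, the indices of one class can be shuffled and cut in two. This is a standard-library sketch with a hypothetical function name (splitRandomHalves), not TMVA's actual implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Hypothetical helper: shuffle the indices 0..nEvents-1 and cut them in half.
// Returns {training indices, testing indices}.
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
splitRandomHalves(std::size_t nEvents, unsigned seed)
{
   std::vector<std::size_t> indices(nEvents);
   std::iota(indices.begin(), indices.end(), std::size_t{0});

   std::mt19937 rng(seed);
   std::shuffle(indices.begin(), indices.end(), rng);

   const auto half = static_cast<std::ptrdiff_t>(nEvents / 2);
   std::vector<std::size_t> train(indices.begin(), indices.begin() + half);
   std::vector<std::size_t> test(indices.begin() + half, indices.end());

   // Every event ends up in exactly one of the two samples.
   assert(train.size() + test.size() == nEvents);
   return {train, test};
}
```

Calling splitRandomHalves(1000, 100) yields two disjoint index sets of 500 events each. Which events land in which half depends on the seed and on the standard library's shuffle.
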

Method specification

In [6]:
TString methodOptions = "";
factory.BookMethod(&dataloader, TMVA::Types::kBDT, "BDT", methodOptions);
<HEADER> Factory                  : Booking method: BDT
                         : 
                         : Rebuilding Dataset dataset
                         : Building event vectors for type 2 Signal
                         : Dataset[dataset] :  create input formulas for tree 
                         : Building event vectors for type 2 Background
                         : Dataset[dataset] :  create input formulas for tree 
<HEADER> DataSetFactory           : [dataset] : Number of events in input trees
                         : 
                         : 
                         : Dataset[dataset] : Weight renormalisation mode: "EqualNumEvents": renormalises all event classes ...
                         : Dataset[dataset] :  such that the effective (weighted) number of events in each class is the same 
                         : Dataset[dataset] :  (and equals the number of events (entries) given for class=0 )
                         : Dataset[dataset] : ... i.e. such that Sum[i=1..N_j]{w_i} = N_classA, j=classA, classB, ...
                         : Dataset[dataset] : ... (note that N_j is the sum of TRAINING events
                         : Dataset[dataset] :  ..... Testing events are not renormalised nor included in the renormalisation factor!)
                         : Number of training and testing events
                         : ---------------------------------------------------------------------------
                         : Signal     -- training events            : 500
                         : Signal     -- testing events             : 500
                         : Signal     -- training and testing events: 1000
                         : Background -- training events            : 500
                         : Background -- testing events             : 500
                         : Background -- training and testing events: 1000
                         : 
<HEADER> DataSetInfo              : Correlation matrix (Signal):
                         : ------------------------
                         :                x       y
                         :       x:  +1.000  +0.030
                         :       y:  +0.030  +1.000
                         : ------------------------
<HEADER> DataSetInfo              : Correlation matrix (Background):
                         : ------------------------
                         :                x       y
                         :       x:  +1.000  -0.022
                         :       y:  -0.022  +1.000
                         : ------------------------
<HEADER> DataSetFactory           : [dataset] :  
                         : 
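
The log describes the EqualNumEvents mode as rescaling the training weights so that every class has the same effective (weighted) number of events, equal to the number of entries of class 0. A minimal sketch of that rule as stated in the log is shown below. It uses only the standard library and a hypothetical function name (renormaliseEqualNumEvents), and it is not TMVA's implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: rescale per-class training weights in place so that the
// weights of each class sum to the number of entries in class 0.
void renormaliseEqualNumEvents(std::vector<std::vector<double>> &weightsPerClass)
{
   assert(!weightsPerClass.empty());
   const double target = static_cast<double>(weightsPerClass[0].size());

   for (auto &weights : weightsPerClass) {
      const double sum = std::accumulate(weights.begin(), weights.end(), 0.0);
      if (sum <= 0.0) continue; // leave empty or zero-weight classes untouched
      const double factor = target / sum;
      for (double &w : weights) w *= factor;
   }
}
```

In this notebook both classes have 500 training events of unit weight, so both factors are 1, which matches the identical reweighted and unweighted event counts printed by the BDT during training.
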

Training and evaluation

In [7]:
factory.TrainAllMethods();
factory.TestAllMethods();
factory.EvaluateAllMethods();
<HEADER> Factory                  : Train all methods
<HEADER> Factory                  : [dataset] : Create Transformation "I" with events from all classes.
                         : 
<HEADER>                          : Transformation, Variable selection : 
                         : Input : variable 'x' <---> Output : variable 'x'
                         : Input : variable 'y' <---> Output : variable 'y'
<HEADER> TFHandler_Factory        : Variable        Mean        RMS   [        Min        Max ]
                         : -----------------------------------------------------------
                         :        x:     1.0229    0.57835   [ 0.00044777     1.9988 ]
                         :        y:     1.4942    0.76640   [   0.014777     2.9933 ]
                         : -----------------------------------------------------------
                         : Ranking input variables (method unspecific)...
<HEADER> IdTransformation         : Ranking result (top variable is best ranked)
                         : --------------------------
                         : Rank : Variable  : Separation
                         : --------------------------
                         :    1 : y         : 5.413e-01
                         :    2 : x         : 4.319e-02
                         : --------------------------
<HEADER> Factory                  : Train method: BDT for Classification
                         : 
<HEADER> BDT                      : #events: (reweighted) sig: 500 bkg: 500
                         : #events: (unweighted) sig: 500 bkg: 500
                         : Training 800 Decision Trees ... patience please
                         : Elapsed time for training with 1000 events: 0.4 sec         
<HEADER> BDT                      : [dataset] : Evaluation of BDT on training sample (1000 events)
                         : Elapsed time for evaluation of 1000 events: 0.0687 sec       
                         : Creating xml weight file: dataset/weights/_BDT.weights.xml
                         : Creating standalone class: dataset/weights/_BDT.class.C
                         : out.root:/dataset/Method_BDT/BDT
<HEADER> Factory                  : Training finished
                         : 
                         : Ranking input variables (method specific)...
<HEADER> BDT                      : Ranking result (top variable is best ranked)
                         : -----------------------------------
                         : Rank : Variable  : Variable Importance
                         : -----------------------------------
                         :    1 : y         : 5.011e-01
                         :    2 : x         : 4.989e-01
                         : -----------------------------------
<HEADER> Factory                  : === Destroy and recreate all methods via weight files for testing ===
                         : 
                         : Reading weight file: dataset/weights/_BDT.weights.xml
<HEADER> Factory                  : Test all methods
<HEADER> Factory                  : Test method: BDT for Classification performance
                         : 
<HEADER> BDT                      : [dataset] : Evaluation of BDT on testing sample (1000 events)
                         : Elapsed time for evaluation of 1000 events: 0.0546 sec       
<HEADER> Factory                  : Evaluate all methods
<HEADER> Factory                  : Evaluate classifier: BDT
                         : 
<HEADER> BDT                      : [dataset] : Loop over test events and fill histograms with classifier response...
                         : 
<HEADER> TFHandler_BDT            : Variable        Mean        RMS   [        Min        Max ]
                         : -----------------------------------------------------------
                         :        x:     1.0136    0.57754   [  0.0011208     1.9999 ]
                         :        y:     1.4938    0.75135   [  0.0054384     2.9981 ]
                         : -----------------------------------------------------------
                         : 
                         : Evaluation results ranked by best signal efficiency and purity (area)
                         : -------------------------------------------------------------------------------------------------------------------
                         : DataSet       MVA                       
                         : Name:         Method:          ROC-integ
                         : dataset       BDT            : 0.870
                         : -------------------------------------------------------------------------------------------------------------------
                         : 
                         : Testing efficiency compared to training efficiency (overtraining check)
                         : -------------------------------------------------------------------------------------------------------------------
                         : DataSet              MVA              Signal efficiency: from test sample (from training sample) 
                         : Name:                Method:          @B=0.01             @B=0.10            @B=0.30   
                         : -------------------------------------------------------------------------------------------------------------------
                         : dataset              BDT            : 0.495 (0.675)       0.622 (0.754)      0.794 (0.908)
                         : -------------------------------------------------------------------------------------------------------------------
                         : 
<HEADER> Dataset:dataset          : Created tree 'TestTree' with 1000 events
                         : 
<HEADER> Dataset:dataset          : Created tree 'TrainTree' with 1000 events
                         : 
<HEADER> Factory                  : Thank you for using TMVA!
                         : For citation information, please visit: http://tmva.sf.net/citeTMVA.html
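
The overtraining table quotes the signal efficiency at fixed background efficiencies (@B=0.01, 0.10, 0.30). To illustrate what such a number means, the sketch below picks the score threshold that lets through at most the requested fraction of background events and reports the fraction of signal events above it. It is a standard-library sketch with a hypothetical function name (signalEffAtBackgroundEff) and an assumed convention that larger scores are more signal-like. TMVA derives its figures from its own evaluation of the classifier response and may differ in detail.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical helper: fraction of signal scores strictly above the threshold
// that keeps at most `backgroundEff` of the background scores.
double signalEffAtBackgroundEff(const std::vector<double> &signal,
                                std::vector<double> background, // copied: sorted below
                                double backgroundEff)
{
   assert(!signal.empty() && !background.empty());
   assert(backgroundEff >= 0.0 && backgroundEff < 1.0);

   // Sort background scores from most to least signal-like.
   std::sort(background.begin(), background.end(), std::greater<double>());

   // Allow the top nPass background events; the cut sits at the next one.
   const auto nPass = static_cast<std::size_t>(backgroundEff * background.size());
   const double threshold = background[nPass];

   const auto nSignalPass = std::count_if(signal.begin(), signal.end(),
                                          [threshold](double s) { return s > threshold; });
   return static_cast<double>(nSignalPass) / static_cast<double>(signal.size());
}
```

A large gap between the value obtained on the training sample and the one obtained on the test sample, as in the @B=0.01 column (0.675 versus 0.495), indicates that the classifier has partly fitted fluctuations of the training sample.
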

Clean up

In [8]:
outFile->Close();

delete outFile;
delete signalTree;
delete backgroundTree;