TMVA Cross Validation Application

This macro provides an example of how to use TMVA for k-fold cross-validation in application.

This requires that CrossValidation was run with a deterministic split, such as "...:splitExpr=int([eventID])%int([numFolds]):...".

  • Project : TMVA - a ROOT-integrated toolkit for multivariate data analysis
  • Package : TMVA
  • Root Macro: TMVACrossValidationApplication

Author: Kim Albertsson (adapted from code originally by Andreas Hoecker)
This notebook tutorial was automatically generated with ROOTBOOK-izer from the macro found in the ROOT repository on Thursday, June 17, 2021 at 06:05 PM.

In [1]:
%%cpp -d
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

#include "TChain.h"
#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TObjString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Tools.h"
#include "TMVA/TMVAGui.h"

Helper function to load data into TTrees.

In [2]:
%%cpp -d
// Appends nPoints randomly generated events to an existing tree.
//
// tree    - tree to fill; its "x", "y" and "eventID" branches are rebound to
//           local buffers for the duration of the call (assumes the caller
//           has already created those branches).
// nPoints - number of events to append; nothing is appended when <= 0.
// offset  - passed as the first argument of TRandom3::Gaus for both x and y.
// scale   - passed as the second argument of TRandom3::Gaus for both x and y.
// seed    - seed for the random generator.
//
// Returns the same tree pointer that was passed in.
TTree *fillTree(TTree * tree, Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)
{
   // Buffers the branches point at while this function fills the tree.
   Float_t xValue = 0;
   Float_t yValue = 0;
   Int_t idValue = 0;

   TRandom3 generator(seed);

   tree->SetBranchAddress("x", &xValue);
   tree->SetBranchAddress("y", &yValue);
   tree->SetBranchAddress("eventID", &idValue);

   for (Int_t remaining = nPoints; remaining > 0; --remaining) {
      // x is drawn before y so the random sequence is consumed in a fixed order.
      xValue = generator.Gaus(offset, scale);
      yValue = generator.Gaus(offset, scale);

      // Ids simply count up from 1 within each call; for this example they
      // only need to be spread evenly and be unrelated to x and y.
      idValue += 1;

      tree->Fill();
   }

   // The buffers above go out of scope on return, so the tree must stop
   // pointing at them before we leave.
   tree->ResetBranchAddresses();
   return tree;
}

This loads the library

In [3]:
// Per the accompanying note, this call loads the TMVA library; the returned
// instance is not used here.
TMVA::Tools::Instance();

Set up the TMVA::Reader

In [4]:
// Reader used to evaluate the booked methods; it is deleted at the end of the
// evaluation cell.
TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent:!V");

// Variables shared between the reader and the input tree: the tree branches
// are later bound to the same addresses, so loading a tree entry updates what
// the reader sees.
Float_t x;
Float_t y;
Int_t eventID;

// The reader is given the addresses of the variables, not their values.
reader->AddVariable("x", &x);
reader->AddVariable("y", &y);
// eventID is registered as a spectator rather than an input variable.
// NOTE(review): presumably this is what the cross-validation split expression
// (e.g. int([eventID])%int([numFolds])) reads -- confirm.
reader->AddSpectator("eventID", &eventID);

Book the serialised methods

In [5]:
TString jobname("TMVACrossValidation");
{
   TString methodName = "BDTG";
   TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");

   Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
   if (weightfileExists) {
      reader->BookMVA(methodName, weightfile);
   } else {
      std::cout << "Weightfile for method " << methodName << " not found."
                   " Did you run TMVACrossValidation with a specified"
                   " splitExpr?" << std::endl;
      exit(0);
   }

}
{
   TString methodName = "Fisher";
   TString weightfile = TString("dataset/weights/") + jobname + "_" + methodName + TString(".weights.xml");

   Bool_t weightfileExists = (gSystem->AccessPathName(weightfile) == kFALSE);
   if (weightfileExists) {
      reader->BookMVA(methodName, weightfile);
   } else {
      std::cout << "Weightfile for method " << methodName << " not found."
                   " Did you run TMVACrossValidation with a specified"
                   " splitExpr?" << std::endl;
      exit(0);
   }
}
                         : Booking "BDTG" of type "CrossValidation" from dataset/weights/TMVACrossValidation_BDTG.weights.xml.
                         : Reading weight file: dataset/weights/TMVACrossValidation_BDTG.weights.xml
<HEADER> DataSetInfo              : [Default] : Added class "Signal"
<HEADER> DataSetInfo              : [Default] : Added class "Background"
                         : Reading weightfile: dataset/weights/TMVACrossValidation_BDTG_fold1.weights.xml
                         : Reading weight file: dataset/weights/TMVACrossValidation_BDTG_fold1.weights.xml
                         : Reading weightfile: dataset/weights/TMVACrossValidation_BDTG_fold2.weights.xml
                         : Reading weight file: dataset/weights/TMVACrossValidation_BDTG_fold2.weights.xml
                         : Booked classifier "BDTG" of type: "CrossValidation"
                         : Booking "Fisher" of type "CrossValidation" from dataset/weights/TMVACrossValidation_Fisher.weights.xml.
                         : Reading weight file: dataset/weights/TMVACrossValidation_Fisher.weights.xml
                         : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
                         : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml
                         : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
                         : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml
                         : Booked classifier "Fisher" of type: "CrossValidation"

Load data

In [6]:
// Build an in-memory tree with the three branches the reader expects,
// initially bound to the shared variables x, y and eventID.
TTree *tree = new TTree();
tree->Branch("x", &x, "x/F");
tree->Branch("y", &y, "y/F");
tree->Branch("eventID", &eventID, "eventID/I");

// Two samples of 1000 points each: one generated with offset +1.0 (seed 100)
// and one with offset -1.0 (seed 101), both with scale 1.0.
// NOTE(review): presumably these play the role of signal and background --
// confirm against the training macro.
fillTree(tree, 1000, 1.0, 1.0, 100);
fillTree(tree, 1000, -1.0, 1.0, 101);
// fillTree resets the branch addresses before returning, so bind the branches
// to the shared variables again before reading entries back.
tree->SetBranchAddress("x", &x);
tree->SetBranchAddress("y", &y);
tree->SetBranchAddress("eventID", &eventID);

Prepare histograms

In [7]:
// One histogram per booked method, each with nbin bins over [-1, 1]; they are
// filled with the classifier responses in the evaluation loop.
Int_t nbin = 100;
TH1F histBDTG{"BDTG", "BDTG", nbin, -1, 1};
TH1F histFisher{"Fisher", "Fisher", nbin, -1, 1};

Evaluate classifiers

In [8]:
// Evaluate both booked classifiers on every entry of the tree and histogram
// their responses. The entry count does not change while reading, so it is
// queried once in the loop initializer.
for (Long64_t ievt = 0, nEntries = tree->GetEntries(); ievt < nEntries; ievt++) {
   // Loads the entry into x, y and eventID, the variables the reader was
   // given the addresses of.
   tree->GetEntry(ievt);

   Double_t valBDTG = reader->EvaluateMVA("BDTG");
   Double_t valFisher = reader->EvaluateMVA("Fisher");

   histBDTG.Fill(valBDTG);
   histFisher.Fill(valFisher);
}

// Detach the tree from the shared variables before destroying it.
tree->ResetBranchAddresses();
delete tree;

if (!gROOT->IsBatch()) {
   // Interactive session: show the two distributions side by side.
   auto c = new TCanvas();
   c->Divide(2,1);
   c->cd(1);
   histBDTG.DrawClone();
   c->cd(2);
   histFisher.DrawClone();
}
else
{ // Write histograms to output file
   // Scoped TFile: it is closed and released when it goes out of scope, with
   // no manual delete.
   TFile target("TMVACrossEvaluationApp.root", "RECREATE");
   if (target.IsZombie()) {
      // Opening failed (e.g. unwritable directory): report it instead of
      // silently attempting to write the histograms.
      std::cerr << "Could not open TMVACrossEvaluationApp.root for writing;"
                   " histograms were not saved." << std::endl;
   } else {
      histBDTG.Write();
      histFisher.Write();
      target.Close();
   }
}

delete reader;

return 0;
                         : Rebuilding Dataset Default