Regression Example

Declare Factory

Initialize the TMVA library, fetch the data sample from GitHub, and create a factory to run the regression.

In [1]:
TMVA::Tools::Instance();

auto inputFile = TFile::Open("https://raw.githubusercontent.com/iml-wg/tmvatutorials/master/inputdata.root");
auto outputFile = TFile::Open("TMVAOutputBDT.root", "RECREATE");

TMVA::Factory factory("TMVARegression", outputFile,
                      "!V:!Silent:Color:DrawProgressBar:AnalysisType=Regression" );

Declare DataLoader

Define the features and the target for the regression.

In [2]:
TMVA::DataLoader loader("dataset"); 

// Add the feature variables; the names refer to branches of the TTree in inputFile
loader.AddVariable("var1");
loader.AddVariable("var2");
loader.AddVariable("var3");
loader.AddVariable("var4");
loader.AddVariable("var5 := var1-var3"); // create new features
loader.AddVariable("var6 := var1+var2");

loader.AddTarget( "target := var2+var3" ); // define the target for the regression

Setup Dataset

Attach the input tree to the DataLoader and define the training and test samples.

In [3]:
TTree *tree;
inputFile->GetObject("Sig", tree);

TCut mycuts = ""; // e.g. TCut mycuts = "abs(var1)<0.5";

loader.AddRegressionTree(tree, 1.0);   // link the TTree to the loader, weight for each event  = 1
loader.PrepareTrainingAndTestTree(mycuts,
                                   "nTrain_Regression=1000:nTest_Regression=1000:SplitMode=Random:NormMode=NumEvents:!V" );
DataSetInfo              : [dataset] : Added class "Regression"
                         : Add Tree Sig of type Regression with 6000 events
                         : Dataset[dataset] : Class index : 0  name : Regression
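
The options nTrain_Regression=1000, nTest_Regression=1000 and SplitMode=Random ask for two random samples drawn from the 6000 events in the tree. As a rough illustration of the idea, and not TMVA's actual implementation, the sketch below shuffles the event indices and takes the first block for training and the next for testing. The function name `randomSplit` and the `seed` parameter are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Hypothetical helper, not TMVA code: shuffle the event indices, then use the
// first nTrain for training and the following nTest for testing.
// Events beyond nTrain + nTest are left unused.
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
randomSplit(std::size_t nEvents, std::size_t nTrain, std::size_t nTest, unsigned seed) {
    assert(nTrain + nTest <= nEvents);
    std::vector<std::size_t> indices(nEvents);
    std::iota(indices.begin(), indices.end(), std::size_t{0});
    std::mt19937 rng(seed);
    std::shuffle(indices.begin(), indices.end(), rng);
    std::vector<std::size_t> train(indices.begin(), indices.begin() + nTrain);
    std::vector<std::size_t> test(indices.begin() + nTrain, indices.begin() + nTrain + nTest);
    return {train, test};
}
```

Calling `randomSplit(6000, 1000, 1000, 42)` would give two disjoint index lists of 1000 events each, with the remaining 4000 events not used.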

Book The Regression Method

Book the method for regression. Here we choose a boosted decision tree (BDT) model. This example trains it with gradient boosting, hence the method name BDTG and the option BoostType=Grad.
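
To make the idea of gradient boosting concrete, here is a minimal sketch under simplifying assumptions: a single input feature, depth-1 trees (stumps) instead of deeper trees, and a squared-error loss instead of the loss chosen in this notebook. It is not TMVA's implementation, and the names `Stump`, `fitStump`, `boost` and `predict` are hypothetical. Each new stump is fitted to the current residuals, and only a fraction `shrinkage` of it is added to the prediction.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Illustrative only, not TMVA code. A stump is a depth-1 tree on one feature:
// it predicts `left` when x < cut and `right` otherwise.
struct Stump {
    double cut;
    double left;
    double right;
    double eval(double x) const { return x < cut ? left : right; }
};

// Least-squares fit of a stump to the residuals r, scanning nCuts equally
// spaced thresholds strictly inside the range of x.
Stump fitStump(const std::vector<double>& x, const std::vector<double>& r, int nCuts) {
    assert(!x.empty() && x.size() == r.size() && nCuts > 0);
    const auto range = std::minmax_element(x.begin(), x.end());
    const double lo = *range.first, hi = *range.second;
    Stump best{lo, 0.0, 0.0};
    double bestErr = std::numeric_limits<double>::infinity();
    for (int c = 1; c <= nCuts; ++c) {
        const double cut = lo + (hi - lo) * c / (nCuts + 1);
        double sumL = 0.0, sumR = 0.0;
        std::size_t nL = 0, nR = 0;
        for (std::size_t i = 0; i < x.size(); ++i) {
            if (x[i] < cut) { sumL += r[i]; ++nL; } else { sumR += r[i]; ++nR; }
        }
        // Each side predicts the mean residual of the events that fall on it.
        const Stump candidate{cut, nL ? sumL / nL : 0.0, nR ? sumR / nR : 0.0};
        double err = 0.0;
        for (std::size_t i = 0; i < x.size(); ++i) {
            const double d = r[i] - candidate.eval(x[i]);
            err += d * d;
        }
        if (err < bestErr) { bestErr = err; best = candidate; }
    }
    return best;
}

// Gradient boosting for squared-error loss: each new stump is fitted to the
// current residuals y - F(x), and only a fraction `shrinkage` of it is added.
std::vector<Stump> boost(const std::vector<double>& x, const std::vector<double>& y,
                         int nTrees, double shrinkage, int nCuts) {
    assert(x.size() == y.size());
    std::vector<double> prediction(x.size(), 0.0), residual(x.size());
    std::vector<Stump> trees;
    for (int t = 0; t < nTrees; ++t) {
        for (std::size_t i = 0; i < x.size(); ++i) residual[i] = y[i] - prediction[i];
        const Stump stump = fitStump(x, residual, nCuts);
        for (std::size_t i = 0; i < x.size(); ++i) prediction[i] += shrinkage * stump.eval(x[i]);
        trees.push_back(stump);
    }
    return trees;
}

// Prediction of the boosted ensemble: the shrunken sum of all stumps.
double predict(const std::vector<Stump>& trees, double shrinkage, double x) {
    double sum = 0.0;
    for (const Stump& s : trees) sum += shrinkage * s.eval(x);
    return sum;
}
```

With a squared-error loss the residuals are the negative gradient of the loss, which is where the name comes from. A small `shrinkage` makes each step more cautious, so more trees are usually needed to reach the same fit.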

Set the hyperparameters in the option string: the number of trees (NTrees), the boosting type (BoostType), the learning rate (Shrinkage), and the maximum tree depth (MaxDepth). Also choose the loss function with RegressionLossFunctionBDTG: 'AbsoluteDeviation', 'Huber', or 'LeastSquares'. nCuts sets how finely each feature's range is scanned for candidate cuts. Larger values take more time but may give more accurate results.
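
For orientation, the three loss functions can be written in their textbook per-event form, as sketched below. This is an illustration and not TMVA's code: the function names are hypothetical, and the Huber threshold `delta` is passed in as a fixed parameter, which is an assumption made here for simplicity. How TMVA chooses that threshold is not reproduced.

```cpp
#include <cassert>
#include <cmath>

// Textbook per-event losses (illustrative, not TMVA code).

// Squared error: penalizes large deviations strongly.
double leastSquaresLoss(double truth, double prediction) {
    const double d = truth - prediction;
    return 0.5 * d * d;
}

// Absolute deviation: grows linearly, so outliers weigh less.
double absoluteDeviationLoss(double truth, double prediction) {
    return std::abs(truth - prediction);
}

// Huber: quadratic for small deviations, linear beyond `delta`.
// `delta` is a hypothetical fixed threshold for this sketch.
double huberLoss(double truth, double prediction, double delta) {
    assert(delta > 0.0);
    const double d = std::abs(truth - prediction);
    return d <= delta ? 0.5 * d * d : delta * (d - 0.5 * delta);
}
```

The Huber loss behaves like least squares near the target and like the absolute deviation far from it, which makes it a compromise between the other two.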

In [4]:
// Boosted Decision Trees 
factory.BookMethod(&loader,TMVA::Types::kBDT, "BDTG",
                   TString("!H:!V:NTrees=64:BoostType=Grad:Shrinkage=0.3:nCuts=20:MaxDepth=4:")+
                   TString("RegressionLossFunctionBDTG=AbsoluteDeviation"));
Factory                  : Booking method: BDTG
                         : 
                         : the option *InverseBoostNegWeights* does not exist for BoostType=Grad --> change
                         : to new default for GradBoost *Pray*
DataSetFactory           : [dataset] : Number of events in input trees
                         : 
                         : Number of training and testing events
                         : ---------------------------------------------------------------------------
                         : Regression -- training events            : 1000
                         : Regression -- testing events             : 1000
                         : Regression -- training and testing events: 2000
                         : 
DataSetInfo              : Correlation matrix (Regression):
                         : --------------------------------------------------------------
                         :               var1    var2    var3    var4 var1-var3 var1+var2
                         :      var1:  +1.000  +0.407  +0.584  +0.809    +0.463    +0.836
                         :      var2:  +0.407  +1.000  +0.688  +0.744    -0.304    +0.842
                         :      var3:  +0.584  +0.688  +1.000  +0.848    -0.450    +0.759
                         :      var4:  +0.809  +0.744  +0.848  +1.000    -0.035    +0.925
                         : var1-var3:  +0.463  -0.304  -0.450  -0.035    +1.000    +0.091
                         : var1+var2:  +0.836  +0.842  +0.759  +0.925    +0.091    +1.000
                         : --------------------------------------------------------------
DataSetFactory           : [dataset] :  
                         : 

Train Method

In [5]:
factory.TrainAllMethods();
Factory                  : Train all methods
Factory                  : [dataset] : Create Transformation "I" with events from all classes.
                         : 
                         : Transformation, Variable selection : 
                         : Input : variable 'var1' <---> Output : variable 'var1'
                         : Input : variable 'var2' <---> Output : variable 'var2'
                         : Input : variable 'var3' <---> Output : variable 'var3'
                         : Input : variable 'var4' <---> Output : variable 'var4'
                         : Input : variable 'var5' <---> Output : variable 'var5'
                         : Input : variable 'var6' <---> Output : variable 'var6'
TFHandler_Factory        : Variable        Mean        RMS   [        Min        Max ]
                         : -----------------------------------------------------------
                         :     var1:    0.23134    0.98776   [    -3.3494     3.0772 ]
                         :     var2:    0.29253     1.0040   [    -3.1385     3.6372 ]
                         :     var3:    0.38554    0.98017   [    -2.6377     3.2477 ]
                         :     var4:    0.79277     1.0743   [    -2.5804     4.2226 ]
                         :     var5:   -0.15420    0.89789   [    -2.9267     3.0035 ]
                         :     var6:    0.52387     1.6706   [    -5.0563     5.2734 ]
                         :   target:    0.67807     1.8230   [    -4.6696     6.0522 ]
                         : -----------------------------------------------------------
                         : Ranking input variables (method unspecific)...
IdTransformation         : Ranking result (top variable is best ranked)
                         : --------------------------------------------
                         : Rank : Variable  : |Correlation with target|
                         : --------------------------------------------
                         :    1 : var2      : 9.208e-01
                         :    2 : var3      : 9.167e-01
                         :    3 : var6      : 8.715e-01
                         :    4 : var4      : 8.654e-01
                         :    5 : var1      : 5.379e-01
                         :    6 : var5      : 4.089e-01
                         : --------------------------------------------
IdTransformation         : Ranking result (top variable is best ranked)
                         : -------------------------------------
                         : Rank : Variable  : Mutual information
                         : -------------------------------------
                         :    1 : var3      : 2.536e+00
                         :    2 : var5      : 2.526e+00
                         :    3 : var6      : 2.474e+00
                         :    4 : var1      : 2.467e+00
                         :    5 : var4      : 2.461e+00
                         :    6 : var2      : 2.415e+00
                         : -------------------------------------
IdTransformation         : Ranking result (top variable is best ranked)
                         : ------------------------------------
                         : Rank : Variable  : Correlation Ratio
                         : ------------------------------------
                         :    1 : var3      : 5.910e-01
                         :    2 : var2      : 5.823e-01
                         :    3 : var4      : 5.240e-01
                         :    4 : var6      : 5.172e-01
                         :    5 : var1      : 2.870e-01
                         :    6 : var5      : 2.781e-01
                         : ------------------------------------
IdTransformation         : Ranking result (top variable is best ranked)
                         : ----------------------------------------
                         : Rank : Variable  : Correlation Ratio (T)
                         : ----------------------------------------
                         :    1 : var2      : 7.967e-01
                         :    2 : var3      : 7.744e-01
                         :    3 : var4      : 6.782e-01
                         :    4 : var6      : 6.646e-01
                         :    5 : var1      : 2.613e-01
                         :    6 : var5      : 2.423e-01
                         : ----------------------------------------
Factory                  : Train method: BDTG for Regression
                         : 
                         : Regression Loss Function: AbsoluteDeviation
                         : Training 64 Decision Trees ... patience please
                         : [>>>>>>>>>>>....] (73%, time left: 0 sec) 
                         : Elapsed time for training with 1000 events: 0.119 sec         
                         : Dataset[dataset] : Create results for training
                         : Dataset[dataset] : Evaluation of BDTG on training sample
                         : Dataset[dataset] : Elapsed time for evaluation of 1000 events: 0.0683 sec       
                         : Create variable histograms
                         : Create regression target histograms
                         : Create regression average deviation
                         : Results created
                         : Creating xml weight file: dataset/weights/TMVARegression_BDTG.weights.xml
                         : [>>>>>>>>>>>>>>>] (100%, time left: 0 sec) 
Factory                  : Training finished
                         : 
Factory                  : === Destroy and recreate all methods via weight files for testing ===
                         : 
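
The first ranking table in the log above orders the inputs by the absolute value of their correlation with the target. Assuming this is the usual Pearson correlation coefficient (the log does not say so explicitly), it could be computed as in the sketch below. The function name `pearson` is hypothetical and this is not TMVA code.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Pearson correlation coefficient between two equally long samples
// (illustrative, not TMVA code). Returns NaN if either sample is constant.
double pearson(const std::vector<double>& a, const std::vector<double>& b) {
    assert(a.size() == b.size() && a.size() > 1);
    const double n = static_cast<double>(a.size());
    double meanA = 0.0, meanB = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) { meanA += a[i]; meanB += b[i]; }
    meanA /= n;
    meanB /= n;
    double cov = 0.0, varA = 0.0, varB = 0.0;
    for (std::size_t i = 0; i < a.size(); ++i) {
        const double da = a[i] - meanA;
        const double db = b[i] - meanB;
        cov += da * db;
        varA += da * da;
        varB += db * db;
    }
    return cov / std::sqrt(varA * varB);
}
```

Ranking by `std::abs(pearson(feature, target))` only captures linear dependence, which is one reason the other tables (mutual information, correlation ratio) can order the variables differently.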

Test and Evaluate the Model

In [6]:
factory.TestAllMethods();
factory.EvaluateAllMethods();
Factory                  : Test all methods
Factory                  : Test method: BDTG for Regression performance
                         : 
                         : Dataset[dataset] : Create results for testing
                         : Dataset[dataset] : Evaluation of BDTG on testing sample
                         : [>>>>>>>>>>>>>>>] (100%, time left: 0 sec) 
                         : Dataset[dataset] : Elapsed time for evaluation of 1000 events: 0.109 sec       
                         : Create variable histograms
                         : Create regression target histograms
                         : Create regression average deviation
                         : Results created
Factory                  : Evaluate all methods
                         : Evaluate regression method: BDTG
TFHandler_BDTG           : Variable        Mean        RMS   [        Min        Max ]
                         : -----------------------------------------------------------
                         :     var1:    0.18427     1.0189   [    -3.3780     3.2875 ]
                         :     var2:    0.28570    0.98438   [    -3.2880     3.4734 ]
                         :     var3:    0.41410    0.99893   [    -2.6232     4.6422 ]
                         :     var4:    0.79156     1.0958   [    -2.9492     4.0073 ]
                         :     var5:   -0.22983    0.90497   [    -2.9587     2.6624 ]
                         :     var6:    0.46997     1.6680   [    -5.1623     5.5897 ]
                         :   target:    0.69980     1.8164   [    -5.9113     7.0144 ]
                         : -----------------------------------------------------------
                         : 
                         : Evaluation results ranked by smallest RMS on test sample:
                         : ("Bias" quotes the mean deviation of the regression from true target.
                         :  "MutInf" is the "Mutual Information" between regression and target.
                         :  Indicated by "_T" are the corresponding "truncated" quantities ob-
                         :  tained when removing events deviating more than 2sigma from average.)
                         : --------------------------------------------------------------------------------------------------
                         : DataSet Name:         MVA Method:        <Bias>   <Bias_T>    RMS    RMS_T  |  MutInf MutInf_T
                         : --------------------------------------------------------------------------------------------------
                         : dataset              BDTG           : -0.00179 -0.00410    0.164    0.133  |  2.447  2.478
                         : --------------------------------------------------------------------------------------------------
                         : 
                         : Evaluation results ranked by smallest RMS on training sample:
                         : (overtraining check)
                         : --------------------------------------------------------------------------------------------------
                         : DataSet Name:         MVA Method:        <Bias>   <Bias_T>    RMS    RMS_T  |  MutInf MutInf_T
                         : --------------------------------------------------------------------------------------------------
                         : dataset              BDTG           :  0.00357  0.00218    0.116   0.0927  |  2.694  2.800
                         : --------------------------------------------------------------------------------------------------
                         : 
Dataset:dataset          : Created tree 'TestTree' with 1000 events
                         : 
Dataset:dataset          : Created tree 'TrainTree' with 1000 events
                         : 
Factory                  : Thank you for using TMVA!
                         : For citation information, please visit: http://tmva.sf.net/citeTMVA.html
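
The evaluation log describes "Bias" as the mean deviation of the regression from the true target, and the "_T" quantities as the same numbers after removing events that deviate by more than 2 sigma from the average. A minimal sketch of that recipe follows. It is not TMVA code, the names `DeviationSummary`, `summarize` and `summarizeTruncated` are hypothetical, and taking the RMS about the bias rather than about zero is an assumption, since the log does not specify it.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative only, not TMVA code.
struct DeviationSummary {
    double bias;  // mean of (prediction - target)
    double rms;   // spread of the deviations about the bias (assumption)
};

// Mean and spread of the deviations, computed in two passes.
DeviationSummary summarize(const std::vector<double>& deviations) {
    assert(!deviations.empty());
    const double n = static_cast<double>(deviations.size());
    double mean = 0.0;
    for (double d : deviations) mean += d;
    mean /= n;
    double sumSq = 0.0;
    for (double d : deviations) sumSq += (d - mean) * (d - mean);
    return {mean, std::sqrt(sumSq / n)};
}

// Truncated version: drop events more than 2 sigma from the average,
// then recompute bias and spread on the remaining events.
DeviationSummary summarizeTruncated(const std::vector<double>& deviations) {
    const DeviationSummary full = summarize(deviations);
    std::vector<double> kept;
    for (double d : deviations) {
        if (std::abs(d - full.bias) <= 2.0 * full.rms) kept.push_back(d);
    }
    return summarize(kept);
}
```

Comparing the full and truncated numbers shows how much a few badly predicted events drive the result: in the tables above, RMS_T is noticeably smaller than RMS for both samples.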

Gather and Plot the Results

Let's plot the residuals of the BDTG predictions. First, close the output file so that its contents are written to disk and it can be reopened without issue. Then retrieve the TestTree, which holds the results on the test set. Finally, plot the residuals.

In [7]:
%jsroot on
outputFile->Close();
auto resultsFile = TFile::Open("TMVAOutputBDT.root");
auto resultsTree = resultsFile->Get("dataset/TestTree"); 
TCanvas c;
resultsTree->Draw("BDTG-target"); // BDTG is the predicted value, target is the true value
c.Draw();