Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-12-16 10:30:14

0001 // @(#)root/tmva $Id$
0002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne, Jan Therhaag, Omar Zapata, Lorenzo Moneta, Sergei Gleyzer
0003 //NOTE: Based on TMVA::Factory
0004 
0005 /**********************************************************************************
0006  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0007  * Package: TMVA                                                                  *
0008  * Class  : DataLoader                                                            *
0009  *                                             *
0010  *                                                                                *
0011  * Description:                                                                   *
0012  *      This is a class to load datasets into every booked method                 *
0013  *                                                                                *
0014  * Authors (alphabetical):                                                        *
0015  *      Lorenzo Moneta <Lorenzo.Moneta@cern.ch> - CERN, Switzerland               *
0016  *      Omar Zapata <andresete.chaos@gmail.com>  - ITM/UdeA, Colombia             *
0017  *      Sergei Gleyzer<sergei.gleyzer@cern.ch> - CERN, Switzerland                *
0018  *                                                                                *
0019  * Copyright (c) 2005-2011:                                                       *
0020  *      CERN, Switzerland                                                         *
0021  *      ITM/UdeA, Colombia                                                        *
0022  *                                                                                *
0023  * Redistribution and use in source and binary forms, with or without             *
0024  * modification, are permitted according to the terms listed in LICENSE           *
0025  * (see tmva/doc/LICENSE)                                          *
0026  **********************************************************************************/
0027 
0028 #ifndef ROOT_TMVA_DataLoader
0029 #define ROOT_TMVA_DataLoader
0030 
0031 #include <vector>
0032 #include "TCut.h"
0033 
0034 #include "TMVA/Configurable.h"
0035 #include "TMVA/Types.h"
0036 #include "TMVA/DataSet.h"
0037 // Forward declarations (global namespace). This header names TTree and TH2 only through pointers; TFile is not referenced by the declarations below.
0038 class TFile;
0039 class TTree;
0040 class TH2;
0041 
0042 namespace TMVA {
0043    // Forward declarations of TMVA classes that DataLoader uses only via pointers or references.
0044    class CvSplit;
0045    class DataInputHandler;
0046    class DataSetInfo;
0047    class DataSetManager;
0048    class VariableTransformBase;
0049 
0050    class DataLoader : public Configurable { // loads datasets into every booked method (see file header)
0051    public:
0052       // construct a loader; the name argument defaults to "default"
0053        DataLoader(TString thedlName="default");
0054 
0055       // virtual destructor (declared here, defined out of line)
0056       virtual ~DataLoader();
0057 
0058 
0059       // add single events (a vector of values plus a weight) to training and testing trees
0060       void AddSignalTrainingEvent    ( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0061       void AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0062       void AddSignalTestEvent        ( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0063       void AddBackgroundTestEvent    ( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0064       void AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight );
0065       void AddTestEvent    ( const TString& className, const std::vector<Double_t>& event, Double_t weight );
0066       void AddEvent        ( const TString& className, Types::ETreeType tt, const std::vector<Double_t>& event, Double_t weight );
0067       Bool_t UserAssignEvents(UInt_t clIndex);
0068       TTree* CreateEventAssignTrees( const TString& name );
0069       // DataSetInfo registration and access; VarTransform takes a transformation definition string
0070       DataSetInfo& AddDataSet( DataSetInfo& );
0071       DataSetInfo& AddDataSet( const TString&  );
0072       DataSetInfo& GetDataSetInfo();
0073       DataLoader* VarTransform(TString trafoDefinition);
0074 
0075       // special case: signal/background
0076 
0077       // Data input related
0078       void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName,
0079                           Double_t signalWeight=1.0, Double_t backgroundWeight=1.0 );
0080       void SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut );
0081       // set the signal and background input trees at once
0082       void SetInputTrees( TTree* signal, TTree* background,
0083                           Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ;
0084       // add a signal tree; overloads accept a TTree* or a TString (datFileS), with the tree type given as enum or as string
0085       void AddSignalTree( TTree* signal,    Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0086       void AddSignalTree( TString datFileS, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0087       void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype );
0088 
0089       // ... deprecated, kept for backwards compatibility
0090       void SetSignalTree( TTree* signal, Double_t weight=1.0);
0091       // add a background tree; same overload set as AddSignalTree
0092       void AddBackgroundTree( TTree* background, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0093       void AddBackgroundTree( TString datFileB,  Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0094       void AddBackgroundTree( TTree* background, Double_t weight, const TString & treetype );
0095 
0096       // ... deprecated, kept for backwards compatibility
0097       void SetBackgroundTree( TTree* background, Double_t weight=1.0 );
0098 
0099       void SetSignalWeightExpression( const TString& variable );
0100       void SetBackgroundWeightExpression( const TString& variable );
0101 
0102       // special case: regression
0103       void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
0104                               Types::ETreeType treetype = Types::kMaxTreeType ) {
0105          AddTree( tree, "Regression", weight, "", treetype ); // forwards to AddTree with class name "Regression" and an empty cut
0106       }
0107 
0108       // general
0109 
0110       // Data input related
0111       void SetTree( TTree* tree, const TString& className, Double_t weight ); ///< deprecated
0112       void AddTree( TTree* tree, const TString& className, Double_t weight=1.0,
0113                     const TCut& cut = "",
0114                     Types::ETreeType tt = Types::kMaxTreeType );
0115       void AddTree( TTree* tree, const TString& className, Double_t weight, const TCut& cut, const TString& treeType );
0116 
0117       // set input variables
0118       void SetInputVariables  ( std::vector<TString>* theVariables ); ///< deprecated
0119 
0120       void AddVariable        ( const TString& expression, const TString& title, const TString& unit,
0121                                 char type='F', Double_t min = 0, Double_t max = 0 );
0122       void AddVariable        ( const TString& expression, char type='F',
0123                                 Double_t min = 0, Double_t max = 0 );
0124 
0125       // NEW: add an array of variables (e.g. for image data) with the provided size
0126       void AddVariablesArray(const TString &expression, int size, char type = 'F',
0127                              Double_t min = 0, Double_t max = 0);
0128 
0129       // targets and spectators: expression plus optional title, unit and min/max (all defaulted)
0130       void AddTarget          ( const TString& expression, const TString& title = "", const TString& unit = "",
0131                                 Double_t min = 0, Double_t max = 0 );
0132       void AddRegressionTarget( const TString& expression, const TString& title = "", const TString& unit = "",
0133                                 Double_t min = 0, Double_t max = 0 )
0134       {
0135          AddTarget( expression, title, unit, min, max ); // alias: passes every argument straight to AddTarget
0136       }
0137       void AddSpectator         ( const TString& expression, const TString& title = "", const TString& unit = "",
0138                                   Double_t min = 0, Double_t max = 0 );
0139 
0140       // set weight for class (className defaults to the empty string)
0141       void SetWeightExpression( const TString& variable, const TString& className = "" );
0142 
0143       // set or add a cut for class (className defaults to the empty string); TString and TCut overloads
0144       void SetCut( const TString& cut, const TString& className = "" );
0145       void SetCut( const TCut& cut, const TString& className = "" );
0146       void AddCut( const TString& cut, const TString& className = "" );
0147       void AddCut( const TCut& cut, const TString& className = "" );
0148 
0149 
0150       //  prepare input tree for training
0151       void PrepareTrainingAndTestTree( const TCut& cut, const TString& splitOpt );
0152       void PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt );
0153 
0154       // ... deprecated, kept for backwards compatibility
0155       void PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest = -1 );
0156 
0157       void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
0158                                        const TString& otherOpt="SplitMode=Random:!V" );
0159 
0160       // Cross validation (fold handling driven by a CvSplit object; tree type defaults to Types::kTraining)
0161       void MakeKFoldDataSet(CvSplit & s);
0162       void PrepareFoldDataSet(CvSplit & s, UInt_t foldNumber, Types::ETreeType tt = Types::kTraining);
0163       void RecombineKFoldDataSet(CvSplit & s, Types::ETreeType tt = Types::kTraining);
0164       // read-only access to the default DataSetInfo (forwards to the private DefaultDataSetInfo())
0165       const DataSetInfo& GetDefaultDataSetInfo(){ return DefaultDataSetInfo(); }
0166       // correlation matrix histogram (TH2*) for the class named className
0167       TH2* GetCorrelationMatrix(const TString& className);
0168 
0169       // Copy method used in VI and CV. DEPRECATED: you can just call Clone:  DataLoader *dl2=(DataLoader *)dl1->Clone("dl2")
0170       DataLoader* MakeCopy(TString name);
0171       friend void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
0172       DataInputHandler&        DataInput() { return *fDataInputHandler; } // dereferences fDataInputHandler without a null check
0173 
0174    private:
0175 
0176 
0177       DataSetInfo&             DefaultDataSetInfo();
0178       void                     SetInputTreesFromEventAssignTrees();
0179 
0180 
0181    private: // second private section: data members only
0182 
0183       // data members
0184 
0185 
0186       DataSetManager* fDataSetManager; // DSMTEST
0187 
0188       // NOTE(review): the "->" in the trailing comment below looks like a ROOT I/O pointer directive rather than prose - confirm before rewording
0189       DataInputHandler*                         fDataInputHandler;  ///<->
0190 
0191       std::vector<TMVA::VariableTransformBase*> fDefaultTrfs;       ///< list of transformations on default DataSet
0192 
0193       // cd to local directory  -- NOTE(review): does not describe the members that follow; possibly a leftover, confirm
0194       TString                                   fOptions;           ///< option string given by construction (presently only "V")
0195       TString                                   fTransformations;   ///< List of transformations to test
0196       Bool_t                                    fVerbose;           ///< verbose mode
0197 
0198       // flag determining the way training and test data are assigned to DataLoader
0199       enum DataAssignType { kUndefined = 0,
0200                             kAssignTrees,
0201                             kAssignEvents };
0202       DataAssignType                            fDataAssignType;    ///< flags for data assigning
0203       std::vector<TTree*>                       fTrainAssignTree;   ///<  for each class: tmp tree if user wants to assign the events directly
0204       std::vector<TTree*>                       fTestAssignTree;    ///<  for each class: tmp tree if user wants to assign the events directly
0205 
0206       Int_t                                     fATreeType = 0;     ///<  type of event (=classIndex)
0207       Float_t                                   fATreeWeight = 0.0; ///<  weight of the event
0208       std::vector<Float_t>                      fATreeEvent;        ///<  event variables
0209 
0210       Types::EAnalysisType                      fAnalysisType;      ///<  the training type
0211 
0212    protected:
0213 
0214       ClassDefOverride(DataLoader,4); // ROOT ClassDef-family macro; the second argument (4) is the class version
0215    };
0216    void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src); // free function, also declared as a friend inside DataLoader; presumably copies src into des - confirm in the implementation
0217 } // namespace TMVA
0218 
0219 #endif
0220