File indexing completed on 2025-12-16 10:30:14
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028 #ifndef ROOT_TMVA_DataLoader
0029 #define ROOT_TMVA_DataLoader
0030
0031 #include <vector>
0032 #include "TCut.h"
0033
0034 #include "TMVA/Configurable.h"
0035 #include "TMVA/Types.h"
0036 #include "TMVA/DataSet.h"
0037
// Forward declarations of global-scope ROOT types. In the visible part of this
// header TTree and TH2 appear only as pointer types (TTree* parameters/members,
// TH2* return value), so full definitions are not required here.
// NOTE(review): TFile is not referenced anywhere else in the visible header —
// confirm whether this forward declaration is still needed.
0038 class TFile;
0039 class TTree;
0040 class TH2;
0041
0042 namespace TMVA {
0043
// Forward declarations of TMVA types that DataLoader only uses by reference or
// by pointer in this header (CvSplit&, DataInputHandler*/&, DataSetInfo&,
// DataSetManager*, VariableTransformBase*), so their headers need not be
// included here.
0044 class CvSplit;
0045 class DataInputHandler;
0046 class DataSetInfo;
0047 class DataSetManager;
0048 class VariableTransformBase;
0049
// TMVA::DataLoader — a Configurable subclass. Judging by its member names it is
// the entry point for registering input trees or single events, declaring
// variables/targets/spectators, setting cuts and weight expressions, and
// preparing training/test (and k-fold) splits. Apart from four inline
// forwarders, every member is only declared here; its behaviour is defined
// out of line and is not visible in this header.
// NOTE(review): the leading four-digit number on every line looks like a
// listing line number from a source-browser export, not part of the code —
// strip it before compiling.
0050 class DataLoader : public Configurable {
0051 public:
0052
// Constructor. Takes the loader name by value; the default name is "default".
// It is not marked explicit, so it also acts as a converting constructor
// from TString.
0053 DataLoader(TString thedlName="default");
0054
0055
// Virtual destructor, defined out of line.
0056 virtual ~DataLoader();
0057
0058
0059
// Event-by-event input. Each call hands over one event as a vector of
// Double_t together with a weight. The Signal/Background variants default
// the weight to 1.0; the className-based variants require an explicit weight.
// AddEvent additionally takes the tree type (Types::ETreeType) explicitly.
// Presumably `event` holds the values in the order the variables were
// declared — TODO confirm against the implementation.
0060 void AddSignalTrainingEvent ( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0061 void AddBackgroundTrainingEvent( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0062 void AddSignalTestEvent ( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0063 void AddBackgroundTestEvent ( const std::vector<Double_t>& event, Double_t weight = 1.0 );
0064 void AddTrainingEvent( const TString& className, const std::vector<Double_t>& event, Double_t weight );
0065 void AddTestEvent ( const TString& className, const std::vector<Double_t>& event, Double_t weight );
0066 void AddEvent ( const TString& className, Types::ETreeType tt, const std::vector<Double_t>& event, Double_t weight );
// Helpers taking a class index / a name; their definitions are not in this
// header. NOTE(review): names suggest they belong to the kAssignEvents input
// mode (see DataAssignType below) — verify.
0067 Bool_t UserAssignEvents(UInt_t clIndex);
0068 TTree* CreateEventAssignTrees( const TString& name );
0069
// Data set registration and access. All three return a non-const
// DataSetInfo reference; which object it refers to is not visible here.
0070 DataSetInfo& AddDataSet( DataSetInfo& );
0071 DataSetInfo& AddDataSet( const TString& );
0072 DataSetInfo& GetDataSetInfo();
// Takes a transformation definition string and returns a DataLoader pointer.
// Ownership of the returned pointer is not expressed in the signature —
// TODO confirm who deletes it.
0073 DataLoader* VarTransform(TString trafoDefinition);
0074
0075
0076
0077
// SetInputTrees overloads:
//  - two file names plus per-class weights (both default to 1.0),
//  - a single tree with separate signal and background cuts,
//  - separate signal and background trees plus per-class weights.
0078 void SetInputTrees( const TString& signalFileName, const TString& backgroundFileName,
0079 Double_t signalWeight=1.0, Double_t backgroundWeight=1.0 );
0080 void SetInputTrees( TTree* inputTree, const TCut& SigCut, const TCut& BgCut );
0081
0082 void SetInputTrees( TTree* signal, TTree* background,
0083 Double_t signalWeight=1.0, Double_t backgroundWeight=1.0) ;
0084
// AddSignalTree overloads: from a TTree or from a data file name, with the
// tree type given as Types::ETreeType (default Types::kMaxTreeType), or from
// a TTree with the tree type given as a string (weight has no default there).
0085 void AddSignalTree( TTree* signal, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0086 void AddSignalTree( TString datFileS, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0087 void AddSignalTree( TTree* signal, Double_t weight, const TString& treetype );
0088
0089
// Set (rather than Add) variant for the signal tree; weight defaults to 1.0.
// NOTE(review): presumably replaces previously added signal trees — confirm.
0090 void SetSignalTree( TTree* signal, Double_t weight=1.0);
0091
// AddBackgroundTree overloads mirror the AddSignalTree overloads above.
0092 void AddBackgroundTree( TTree* background, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0093 void AddBackgroundTree( TString datFileB, Double_t weight=1.0, Types::ETreeType treetype = Types::kMaxTreeType );
0094 void AddBackgroundTree( TTree* background, Double_t weight, const TString & treetype );
0095
0096
// Set variant for the background tree; weight defaults to 1.0.
0097 void SetBackgroundTree( TTree* background, Double_t weight=1.0 );
0098
// Per-class weight expressions, passed as strings.
0099 void SetSignalWeightExpression( const TString& variable );
0100 void SetBackgroundWeightExpression( const TString& variable );
0101
0102
// Inline convenience wrapper: forwards to AddTree with the class name fixed to
// "Regression" and an empty cut (""). Because `treetype` is a
// Types::ETreeType, the call resolves to the AddTree overload whose last
// parameter is Types::ETreeType.
0103 void AddRegressionTree( TTree* tree, Double_t weight = 1.0,
0104 Types::ETreeType treetype = Types::kMaxTreeType ) {
0105 AddTree( tree, "Regression", weight, "", treetype );
0106 }
0107
0108
0109
0110
// General tree registration for an arbitrary class name. SetTree requires an
// explicit weight. The first AddTree overload defaults weight to 1.0, the cut
// to "" and the tree type to Types::kMaxTreeType; the second takes the tree
// type as a string and has no defaults.
0111 void SetTree( TTree* tree, const TString& className, Double_t weight );
0112 void AddTree( TTree* tree, const TString& className, Double_t weight=1.0,
0113 const TCut& cut = "",
0114 Types::ETreeType tt = Types::kMaxTreeType );
0115 void AddTree( TTree* tree, const TString& className, Double_t weight, const TCut& cut, const TString& treeType );
0116
0117
// Takes a pointer to a vector of variable names/expressions. Whether the
// pointer may be null and whether ownership is transferred is not visible
// here — TODO confirm.
0118 void SetInputVariables ( std::vector<TString>* theVariables );
0119
// AddVariable overloads: with or without title and unit. `type` is a single
// character defaulting to 'F'; min and max both default to 0.
// NOTE(review): min == max == 0 presumably means "range not specified" —
// confirm in the implementation.
0120 void AddVariable ( const TString& expression, const TString& title, const TString& unit,
0121 char type='F', Double_t min = 0, Double_t max = 0 );
0122 void AddVariable ( const TString& expression, char type='F',
0123 Double_t min = 0, Double_t max = 0 );
0124
0125
// Array variant: same defaults as AddVariable plus an explicit element
// count `size` (plain int).
0126 void AddVariablesArray(const TString &expression, int size, char type = 'F',
0127 Double_t min = 0, Double_t max = 0);
0128
0129
// Targets and spectators: expression plus optional title, unit and range
// (title/unit default to "", min/max to 0).
0130 void AddTarget ( const TString& expression, const TString& title = "", const TString& unit = "",
0131 Double_t min = 0, Double_t max = 0 );
// Inline alias: forwards all five arguments unchanged to AddTarget.
0132 void AddRegressionTarget( const TString& expression, const TString& title = "", const TString& unit = "",
0133 Double_t min = 0, Double_t max = 0 )
0134 {
0135 AddTarget( expression, title, unit, min, max );
0136 }
0137 void AddSpectator ( const TString& expression, const TString& title = "", const TString& unit = "",
0138 Double_t min = 0, Double_t max = 0 );
0139
0140
// Weight expression for a class; className defaults to the empty string.
// NOTE(review): the meaning of an empty className (e.g. "all classes") is
// not visible in this header — confirm.
0141 void SetWeightExpression( const TString& variable, const TString& className = "" );
0142
0143
// Cut handling: Set and Add variants, each overloaded for TString and TCut.
// className defaults to the empty string (same caveat as above).
0144 void SetCut( const TString& cut, const TString& className = "" );
0145 void SetCut( const TCut& cut, const TString& className = "" );
0146 void AddCut( const TString& cut, const TString& className = "" );
0147 void AddCut( const TCut& cut, const TString& className = "" );
0148
0149
0150
// PrepareTrainingAndTestTree overloads:
//  - one common cut plus a split-option string,
//  - separate signal/background cuts (taken by value) plus a split-option
//    string,
//  - one cut plus event counts Ntrain / Ntest (Ntest defaults to -1),
//  - one cut plus per-class train/test counts and an option string that
//    defaults to "SplitMode=Random:!V".
0151 void PrepareTrainingAndTestTree( const TCut& cut, const TString& splitOpt );
0152 void PrepareTrainingAndTestTree( TCut sigcut, TCut bkgcut, const TString& splitOpt );
0153
0154
0155 void PrepareTrainingAndTestTree( const TCut& cut, Int_t Ntrain, Int_t Ntest = -1 );
0156
0157 void PrepareTrainingAndTestTree( const TCut& cut, Int_t NsigTrain, Int_t NbkgTrain, Int_t NsigTest, Int_t NbkgTest,
0158 const TString& otherOpt="SplitMode=Random:!V" );
0159
0160
// K-fold helpers. Each takes the CvSplit by non-const reference, so the
// callee may modify it. The tree type defaults to Types::kTraining where
// present.
0161 void MakeKFoldDataSet(CvSplit & s);
0162 void PrepareFoldDataSet(CvSplit & s, UInt_t foldNumber, Types::ETreeType tt = Types::kTraining);
0163 void RecombineKFoldDataSet(CvSplit & s, Types::ETreeType tt = Types::kTraining);
0164
// Inline read-only accessor: returns the result of the private
// DefaultDataSetInfo() as a const reference. The accessor itself is not
// const-qualified, so it cannot be called on a const DataLoader.
0165 const DataSetInfo& GetDefaultDataSetInfo(){ return DefaultDataSetInfo(); }
0166
// Returns a TH2 pointer for the given class name. Ownership of the
// histogram is not expressed in the signature — TODO confirm.
0167 TH2* GetCorrelationMatrix(const TString& className);
0168
0169
// Returns a DataLoader pointer for the given name; ownership of the result
// is not visible here — TODO confirm.
0170 DataLoader* MakeCopy(TString name);
// Grants the namespace-scope DataLoaderCopy function access to the private
// members of this class.
0171 friend void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
// Inline accessor: dereferences fDataInputHandler without a null check, so
// the pointer must be valid when this is called.
0172 DataInputHandler& DataInput() { return *fDataInputHandler; }
0173
0174 private:
0175
0176
// Internal helpers, defined out of line. DefaultDataSetInfo is what the
// public GetDefaultDataSetInfo() forwards to.
0177 DataSetInfo& DefaultDataSetInfo();
0178 void SetInputTreesFromEventAssignTrees();
0179
0180
// Second `private:` label; access is already private at this point, so it
// only separates the helper functions above from the data members below.
0181 private:
0182
0183
0184
0185
// Raw pointer with no in-class initialiser; ownership and lifetime are not
// visible in this header.
0186 DataSetManager* fDataSetManager;
0187
0188
// Raw pointer with no in-class initialiser; dereferenced by DataInput().
0189 DataInputHandler* fDataInputHandler;
0190
// Vector of raw VariableTransformBase pointers; ownership of the pointees is
// not visible here.
0191 std::vector<TMVA::VariableTransformBase*> fDefaultTrfs;
0192
0193
// Option/transformation strings and verbosity flag. fVerbose has no in-class
// initialiser, so it must be set by the constructor.
0194 TString fOptions;
0195 TString fTransformations;
0196 Bool_t fVerbose;
0197
0198
// How input data is supplied. kUndefined is explicitly 0; kAssignTrees and
// kAssignEvents take the following values 1 and 2.
0199 enum DataAssignType { kUndefined = 0,
0200 kAssignTrees,
0201 kAssignEvents };
// Current input mode; no in-class initialiser.
0202 DataAssignType fDataAssignType;
// Per-entry TTree pointers for training and test data.
// NOTE(review): presumably indexed by class index (cf. UserAssignEvents) —
// confirm.
0203 std::vector<TTree*> fTrainAssignTree;
0204 std::vector<TTree*> fTestAssignTree;
0205
// Scratch values with in-class initialisers (0 and 0.0) plus an event buffer.
// NOTE(review): the public Add*Event methods take Double_t values while
// these are Float_t; if events pass through these buffers a narrowing
// conversion takes place — confirm in the implementation.
0206 Int_t fATreeType = 0;
0207 Float_t fATreeWeight = 0.0;
0208 std::vector<Float_t> fATreeEvent;
0209
// Analysis type; no in-class initialiser.
0210 Types::EAnalysisType fAnalysisType;
0211
0212 protected:
0213
// ROOT dictionary macro; per the ROOT documentation the second argument (4)
// is the class version number.
0214 ClassDefOverride(DataLoader,4);
0215 };
// Namespace-scope declaration of the function that DataLoader names as a
// friend. It takes two DataLoader pointers; the parameter names suggest
// `des` = destination and `src` = source — the definition is not in this
// header, so confirm the copy direction and null-pointer handling there.
0216 void DataLoaderCopy(TMVA::DataLoader* des, TMVA::DataLoader* src);
0217 }
0218
0219 #endif
0220