File indexing completed on 2025-01-18 10:10:58
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifndef ROOT_TMVA_DataSetFactory
0030 #define ROOT_TMVA_DataSetFactory
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040 #include <vector>
0041 #include <map>
0042
0043 #include "TString.h"
0044 #include "TTree.h"
0045 #include "TCut.h"
0046 #include "TTreeFormula.h"
0047 #include "TMatrixDfwd.h"
0048 #include "TPrincipal.h"
0049 #include "TRandom3.h"
0050
0051 #include "TMVA/Types.h"
0052 #include "TMVA/VariableInfo.h"
0053 #include "TMVA/Event.h"
0054
0055 namespace TMVA {
0056
0057 class DataSet;
0058 class DataSetInfo;
0059 class DataInputHandler;
0060 class TreeInfo;
0061 class MsgLogger;
0062
0063
0064
0065
0066
0067
/// Function object that deletes whatever pointer it is handed.
/// Every call yields a reference to the functor itself, so calls may be chained
/// (e.g. when used with std::for_each).
template<class T>
struct DeleteFunctor_t
{
   auto operator()(const T* ptr) -> DeleteFunctor_t&
   {
      delete ptr;   // deleting nullptr is a well-defined no-op
      return *this;
   }
};
0076
0077 template<class T>
0078 DeleteFunctor_t<const T> DeleteFunctor()
0079 {
0080 return DeleteFunctor_t<const T>();
0081 }
0082
0083
/// Generator functor: successive calls return start, start+1, start+2, ...
/// (advancing with T's postfix increment operator).
template< typename T >
class Increment {
public:
   Increment( T start ) : value( start ) {}

   /// Hand out the current value, then advance to the next one.
   T operator()()
   {
      return value++;
   }

private:
   T value;   // the value the next call will return
};
0093
0094
0095
/// Identity function object: returns (a copy of) its argument unchanged.
template <typename F>
class null_t
{
public:
   using argument_type = F;

   F operator()(const F& arg) const { return arg; }
};
0108
0109 template <typename F>
0110 inline null_t<F> null() {
0111 return null_t<F>();
0112 }
0113
0114
0115
0116 class DataSetFactory:public TObject {
0117
0118 typedef std::vector<Event* > EventVector;
0119 typedef std::vector< EventVector > EventVectorOfClasses;
0120 typedef std::map<Types::ETreeType, EventVectorOfClasses > EventVectorOfClassesOfTreeType;
0121 typedef std::map<Types::ETreeType, EventVector > EventVectorOfTreeType;
0122
0123 typedef std::vector< Double_t > ValuePerClass;
0124 typedef std::map<Types::ETreeType, ValuePerClass > ValuePerClassOfTreeType;
0125
0126 class EventStats {
0127 public:
0128 Int_t nTrainingEventsRequested;
0129 Int_t nTestingEventsRequested;
0130 Float_t TrainTestSplitRequested;
0131 Int_t nInitialEvents;
0132 Int_t nEvBeforeCut;
0133 Int_t nEvAfterCut;
0134 Float_t nWeEvBeforeCut;
0135 Float_t nWeEvAfterCut;
0136 Double_t nNegWeights;
0137 Float_t* varAvLength;
0138 EventStats():
0139 nTrainingEventsRequested(0),
0140 nTestingEventsRequested(0),
0141 TrainTestSplitRequested(0),
0142 nInitialEvents(0),
0143 nEvBeforeCut(0),
0144 nEvAfterCut(0),
0145 nWeEvBeforeCut(0),
0146 nWeEvAfterCut(0),
0147 nNegWeights(0),
0148 varAvLength(nullptr)
0149 {}
0150 ~EventStats() { delete[] varAvLength; }
0151 Float_t cutScaling() const { return Float_t(nEvAfterCut)/nEvBeforeCut; }
0152 };
0153
0154 typedef std::vector< int > NumberPerClass;
0155 typedef std::vector< EventStats > EvtStatsPerClass;
0156
0157 public:
0158
0159 ~DataSetFactory();
0160
0161 DataSetFactory();
0162
0163 DataSet* CreateDataSet( DataSetInfo &, DataInputHandler& );
0164 protected:
0165
0166
0167 DataSet* BuildInitialDataSet( DataSetInfo&, TMVA::DataInputHandler& );
0168 DataSet* BuildDynamicDataSet( DataSetInfo& );
0169
0170
0171 void BuildEventVector ( DataSetInfo& dsi,
0172 DataInputHandler& dataInput,
0173 EventVectorOfClassesOfTreeType& eventsmap,
0174 EvtStatsPerClass& eventCounts);
0175
0176 DataSet* MixEvents ( DataSetInfo& dsi,
0177 EventVectorOfClassesOfTreeType& eventsmap,
0178 EvtStatsPerClass& eventCounts,
0179 const TString& splitMode,
0180 const TString& mixMode,
0181 const TString& normMode,
0182 UInt_t splitSeed);
0183
0184 void RenormEvents ( DataSetInfo& dsi,
0185 EventVectorOfClassesOfTreeType& eventsmap,
0186 const EvtStatsPerClass& eventCounts,
0187 const TString& normMode );
0188
0189 void InitOptions ( DataSetInfo& dsi,
0190 EvtStatsPerClass& eventsmap,
0191 TString& normMode, UInt_t& splitSeed,
0192 TString& splitMode, TString& mixMode);
0193
0194
0195
0196
0197
0198 TMatrixD* CalcCorrelationMatrix( DataSet*, const UInt_t classNumber );
0199 TMatrixD* CalcCovarianceMatrix ( DataSet*, const UInt_t classNumber );
0200 void CalcMinMax ( DataSet*, DataSetInfo& dsi );
0201
0202
0203 void ResetBranchAndEventAddresses( TTree* );
0204 void ResetCurrentTree() { fCurrentTree = nullptr; }
0205 void ChangeToNewTree( TreeInfo&, const DataSetInfo & );
0206 Bool_t CheckTTreeFormula( TTreeFormula* ttf, const TString& expression, Bool_t& hasDollar );
0207
0208
0209 Bool_t Verbose() { return fVerbose; }
0210
0211
0212
0213
0214 Bool_t fVerbose;
0215 TString fVerboseLevel;
0216
0217
0218 Bool_t fCorrelations = kFALSE;
0219 Bool_t fComputeCorrelations = kFALSE;
0220
0221 Bool_t fScaleWithPreselEff;
0222
0223
0224 TTree* fCurrentTree;
0225 UInt_t fCurrentEvtIdx;
0226
0227
0228 std::vector<TTreeFormula*> fInputFormulas;
0229 std::vector<std::pair<TTreeFormula*, Int_t>> fInputTableFormulas;
0230 std::vector<TTreeFormula *> fTargetFormulas;
0231 std::vector<TTreeFormula*> fCutFormulas;
0232 std::vector<TTreeFormula*> fWeightFormula;
0233 std::vector<TTreeFormula*> fSpectatorFormulas;
0234
0235 MsgLogger* fLogger;
0236 MsgLogger& Log() const { return *fLogger; }
0237 public:
0238 ClassDef(DataSetFactory, 2);
0239 };
0240 }
0241
0242 #endif