Warning, file /include/root/TMVA/DataSetFactory.h was not indexed
or was modified since last indexation (in which case cross-reference links may be missing, inaccurate or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifndef ROOT_TMVA_DataSetFactory
0030 #define ROOT_TMVA_DataSetFactory
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040 #include <vector>
0041 #include <map>
0042
0043 #include "TString.h"
0044 #include "TTree.h"
0045 #include "TCut.h"
0046 #include "TTreeFormula.h"
0047 #include "TMatrixDfwd.h"
0048 #include "TPrincipal.h"
0049 #include "TRandom3.h"
0050
0051 #include "TMVA/Types.h"
0052 #include "TMVA/VariableInfo.h"
0053 #include "TMVA/Event.h"
0054
0055 namespace TMVA {
0056
0057 class DataSet;
0058 class DataSetInfo;
0059 class DataInputHandler;
0060 class TreeInfo;
0061 class MsgLogger;
0062
0063
0064
0065
0066
0067
/// Function object that deletes the object behind the pointer it is given.
///
/// operator() hands back a reference to the functor itself, so calls can be
/// chained. Passing nullptr is harmless: deleting a null pointer is a no-op.
template <typename T>
struct DeleteFunctor_t
{
   DeleteFunctor_t& operator()(const T* pointee)
   {
      delete pointee;
      DeleteFunctor_t& self = *this;
      return self;
   }
};
0076
0077 template<class T>
0078 DeleteFunctor_t<const T> DeleteFunctor()
0079 {
0080 return DeleteFunctor_t<const T>();
0081 }
0082
0083
/// Nullary generator object.
///
/// Each call returns the stored value and then post-increments it. For an
/// arithmetic T, successive calls therefore yield start, start+1, start+2, ...
template <typename T>
class Increment {
public:
   Increment(T start) : fValue(start) {}

   /// Returns the current value and advances it via T's post-increment.
   T operator()() { return fValue++; }

private:
   T fValue;
};
0093
0094
0095
/// Identity function object: operator() hands back a copy of its argument.
///
/// argument_type names the parameter type F for code that inspects it.
template <typename F>
class null_t
{
public:
   using argument_type = F;

   F operator()(const F& argF) const { return argF; }
};
0108
0109 template <typename F>
0110 inline null_t<F> null() {
0111 return null_t<F>();
0112 }
0113
0114
0115
0116 class DataSetFactory:public TObject {
0117
0118 typedef std::vector<Event* > EventVector;
0119 typedef std::vector< EventVector > EventVectorOfClasses;
0120 typedef std::map<Types::ETreeType, EventVectorOfClasses > EventVectorOfClassesOfTreeType;
0121 typedef std::map<Types::ETreeType, EventVector > EventVectorOfTreeType;
0122
0123 typedef std::vector< Double_t > ValuePerClass;
0124 typedef std::map<Types::ETreeType, ValuePerClass > ValuePerClassOfTreeType;
0125
0126 class EventStats {
0127 public:
0128 Int_t nTrainingEventsRequested;
0129 Int_t nTestingEventsRequested;
0130 Float_t TrainTestSplitRequested;
0131 Int_t nInitialEvents;
0132 Int_t nEvBeforeCut;
0133 Int_t nEvAfterCut;
0134 Float_t nWeEvBeforeCut;
0135 Float_t nWeEvAfterCut;
0136 Double_t nNegWeights;
0137 Float_t* varAvLength;
0138 EventStats():
0139 nTrainingEventsRequested(0),
0140 nTestingEventsRequested(0),
0141 TrainTestSplitRequested(0),
0142 nInitialEvents(0),
0143 nEvBeforeCut(0),
0144 nEvAfterCut(0),
0145 nWeEvBeforeCut(0),
0146 nWeEvAfterCut(0),
0147 nNegWeights(0),
0148 varAvLength(nullptr)
0149 {}
0150 ~EventStats() { delete[] varAvLength; }
0151 Float_t cutScaling() const { return Float_t(nEvAfterCut)/nEvBeforeCut; }
0152 };
0153
0154 typedef std::vector< int > NumberPerClass;
0155 typedef std::vector< EventStats > EvtStatsPerClass;
0156
0157 public:
0158
0159 ~DataSetFactory();
0160
0161 DataSetFactory();
0162
0163 DataSet* CreateDataSet( DataSetInfo &, DataInputHandler& );
0164 protected:
0165
0166
0167 DataSet* BuildInitialDataSet( DataSetInfo&, TMVA::DataInputHandler& );
0168 DataSet* BuildDynamicDataSet( DataSetInfo& );
0169
0170
0171 void BuildEventVector ( DataSetInfo& dsi,
0172 DataInputHandler& dataInput,
0173 EventVectorOfClassesOfTreeType& eventsmap,
0174 EvtStatsPerClass& eventCounts);
0175
0176 DataSet* MixEvents ( DataSetInfo& dsi,
0177 EventVectorOfClassesOfTreeType& eventsmap,
0178 EvtStatsPerClass& eventCounts,
0179 const TString& splitMode,
0180 const TString& mixMode,
0181 const TString& normMode,
0182 UInt_t splitSeed);
0183
0184 void RenormEvents ( DataSetInfo& dsi,
0185 EventVectorOfClassesOfTreeType& eventsmap,
0186 const EvtStatsPerClass& eventCounts,
0187 const TString& normMode );
0188
0189 void InitOptions ( DataSetInfo& dsi,
0190 EvtStatsPerClass& eventsmap,
0191 TString& normMode, UInt_t& splitSeed,
0192 TString& splitMode, TString& mixMode);
0193
0194
0195
0196
0197
0198 TMatrixD* CalcCorrelationMatrix( DataSet*, const UInt_t classNumber );
0199 TMatrixD* CalcCovarianceMatrix ( DataSet*, const UInt_t classNumber );
0200 void CalcMinMax ( DataSet*, DataSetInfo& dsi );
0201
0202
0203 void ResetBranchAndEventAddresses( TTree* );
0204 void ResetCurrentTree() { fCurrentTree = nullptr; }
0205 void ChangeToNewTree( TreeInfo&, const DataSetInfo & );
0206 Bool_t CheckTTreeFormula( TTreeFormula* ttf, const TString& expression, Bool_t& hasDollar );
0207
0208
0209 Bool_t Verbose() { return fVerbose; }
0210
0211
0212
0213
0214 Bool_t fVerbose;
0215 TString fVerboseLevel;
0216
0217
0218 Bool_t fCorrelations = kFALSE;
0219 Bool_t fComputeCorrelations = kFALSE;
0220
0221 Bool_t fScaleWithPreselEff;
0222
0223
0224 TTree* fCurrentTree;
0225 UInt_t fCurrentEvtIdx;
0226
0227
0228 std::vector<TTreeFormula*> fInputFormulas;
0229 std::vector<std::pair<TTreeFormula*, Int_t>> fInputTableFormulas;
0230 std::vector<TTreeFormula *> fTargetFormulas;
0231 std::vector<TTreeFormula*> fCutFormulas;
0232 std::vector<TTreeFormula*> fWeightFormula;
0233 std::vector<TTreeFormula*> fSpectatorFormulas;
0234
0235 MsgLogger* fLogger;
0236 MsgLogger& Log() const { return *fLogger; }
0237 public:
0238 ClassDefOverride(DataSetFactory, 2);
0239 };
0240 }
0241
0242 #endif