File indexing completed on 2025-01-30 10:22:50
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifndef ROOT_TMVA_DataSet
0030 #define ROOT_TMVA_DataSet
0031
0032
0033
0034
0035
0036
0037
0038
0039
0040 #include <vector>
0041 #include <map>
0042
0043 #include "TNamed.h"
0044 #include "TString.h"
0045 #include "TTree.h"
0046 #include "TRandom3.h"
0047
0048 #include "TMVA/Types.h"
0049 #include "TMVA/VariableInfo.h"
0050
0051 namespace TMVA {
0052
0053 class Event;
0054 class DataSetInfo;
0055 class MsgLogger;
0056 class Results;
0057
0058 class DataSet :public TNamed {
0059
0060 public:
0061 DataSet();
0062 DataSet(const DataSetInfo&);
0063 virtual ~DataSet();
0064
0065 void AddEvent( Event *, Types::ETreeType );
0066
0067 Long64_t GetNEvents( Types::ETreeType type = Types::kMaxTreeType ) const;
0068 Long64_t GetNTrainingEvents() const { return GetNEvents(Types::kTraining); }
0069 Long64_t GetNTestEvents() const { return GetNEvents(Types::kTesting); }
0070
0071
0072 const Event* GetEvent() const;
0073 const Event* GetEvent ( Long64_t ievt ) const { fCurrentEventIdx = ievt; return GetEvent(); }
0074 const Event* GetTrainingEvent( Long64_t ievt ) const { return GetEvent(ievt, Types::kTraining); }
0075 const Event* GetTestEvent ( Long64_t ievt ) const { return GetEvent(ievt, Types::kTesting); }
0076 const Event* GetEvent ( Long64_t ievt, Types::ETreeType type ) const
0077 {
0078 fCurrentTreeIdx = TreeIndex(type); fCurrentEventIdx = ievt; return GetEvent();
0079 }
0080
0081
0082
0083
0084 UInt_t GetNVariables() const;
0085 UInt_t GetNTargets() const;
0086 UInt_t GetNSpectators() const;
0087
0088 void SetCurrentEvent( Long64_t ievt ) const { fCurrentEventIdx = ievt; }
0089 void SetCurrentType ( Types::ETreeType type ) const { fCurrentTreeIdx = TreeIndex(type); }
0090 Types::ETreeType GetCurrentType() const;
0091
0092 void SetEventCollection( std::vector<Event*>*, Types::ETreeType, Bool_t deleteEvents = true );
0093 const std::vector<Event*>& GetEventCollection( Types::ETreeType type = Types::kMaxTreeType ) const;
0094 const TTree* GetEventCollectionAsTree();
0095
0096 Long64_t GetNEvtSigTest();
0097 Long64_t GetNEvtBkgdTest();
0098 Long64_t GetNEvtSigTrain();
0099 Long64_t GetNEvtBkgdTrain();
0100
0101 Bool_t HasNegativeEventWeights() const { return fHasNegativeEventWeights; }
0102
0103 Results* GetResults ( const TString &,
0104 Types::ETreeType type,
0105 Types::EAnalysisType analysistype );
0106 void DeleteResults ( const TString &,
0107 Types::ETreeType type,
0108 Types::EAnalysisType analysistype );
0109 void DeleteAllResults(Types::ETreeType type,
0110 Types::EAnalysisType analysistype);
0111
0112 void SetVerbose( Bool_t ) {}
0113
0114
0115
0116 void DivideTrainingSet( UInt_t blockNum );
0117
0118
0119 void MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges = kTRUE );
0120
0121 void IncrementNClassEvents( Int_t type, UInt_t classNumber );
0122 Long64_t GetNClassEvents ( Int_t type, UInt_t classNumber );
0123 void ClearNClassEvents ( Int_t type );
0124
0125 TTree* GetTree( Types::ETreeType type );
0126
0127
0128 void InitSampling( Float_t fraction, Float_t weight, UInt_t seed = 0 );
0129 void EventResult( Bool_t successful, Long64_t evtNumber = -1 );
0130 void CreateSampling() const;
0131
0132 UInt_t TreeIndex(Types::ETreeType type) const;
0133
0134 private:
0135
0136
0137 void DestroyCollection( Types::ETreeType type, Bool_t deleteEvents );
0138
0139 const DataSetInfo *fdsi;
0140
0141 std::vector< std::vector<Event*> > fEventCollection;
0142
0143 std::vector< std::map< TString, Results* > > fResults;
0144
0145 mutable UInt_t fCurrentTreeIdx;
0146 mutable Long64_t fCurrentEventIdx;
0147
0148
0149 std::vector<Char_t> fSampling;
0150 std::vector<Int_t> fSamplingNEvents;
0151 std::vector<Float_t> fSamplingWeight;
0152 mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingEventList;
0153 mutable std::vector< std::vector< std::pair< Float_t, Long64_t > > > fSamplingSelected;
0154 TRandom3 *fSamplingRandom;
0155
0156
0157
0158 std::vector< std::vector<Long64_t> > fClassEvents;
0159
0160
0161 Bool_t fHasNegativeEventWeights;
0162
0163 mutable MsgLogger* fLogger;
0164 MsgLogger& Log() const { return *fLogger; }
0165 std::vector<Char_t> fBlockBelongToTraining;
0166
0167
0168
0169 Long64_t fTrainingBlockSize;
0170
0171 void ApplyTrainingBlockDivision();
0172 void ApplyTrainingSetDivision();
0173 public:
0174
0175 ClassDef(DataSet,1);
0176 };
0177 }
0178
0179
0180
0181 inline UInt_t TMVA::DataSet::TreeIndex(Types::ETreeType type) const
0182 {
0183 switch (type) {
0184 case Types::kMaxTreeType : return fCurrentTreeIdx;
0185 case Types::kTraining : return 0;
0186 case Types::kTesting : return 1;
0187 case Types::kValidation : return 2;
0188 case Types::kTrainingOriginal : return 3;
0189 default : return fCurrentTreeIdx;
0190 }
0191 }
0192
0193
0194 inline TMVA::Types::ETreeType TMVA::DataSet::GetCurrentType() const
0195 {
0196 switch (fCurrentTreeIdx) {
0197 case 0: return Types::kTraining;
0198 case 1: return Types::kTesting;
0199 case 2: return Types::kValidation;
0200 case 3: return Types::kTrainingOriginal;
0201 }
0202 return Types::kMaxTreeType;
0203 }
0204
0205
0206 inline Long64_t TMVA::DataSet::GetNEvents(Types::ETreeType type) const
0207 {
0208 Int_t treeIdx = TreeIndex(type);
0209 if (fSampling.size() > UInt_t(treeIdx) && fSampling.at(treeIdx)) {
0210 return fSamplingSelected.at(treeIdx).size();
0211 }
0212 return GetEventCollection(type).size();
0213 }
0214
0215
0216 inline const std::vector<TMVA::Event*>& TMVA::DataSet::GetEventCollection( TMVA::Types::ETreeType type ) const
0217 {
0218 return fEventCollection.at(TreeIndex(type));
0219 }
0220
0221
0222 #endif