File indexing completed on 2025-01-30 10:22:51
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027 #ifndef ROOT_TMVA_DataSetInfo
0028 #define ROOT_TMVA_DataSetInfo
0029
0030
0031
0032
0033
0034
0035
0036
0037
0038 #include <iosfwd>
0039 #include <vector>
0040 #include <map>
0041
0042 #include "TObject.h"
0043 #include "TString.h"
0044 #include "TTree.h"
0045 #include "TCut.h"
0046 #include "TMatrixDfwd.h"
0047
0048 #include "TMVA/Types.h"
0049 #include "TMVA/VariableInfo.h"
0050 #include "TMVA/ClassInfo.h"
0051 #include "TMVA/Event.h"
0052
0053 class TH2;
0054
0055 namespace TMVA {
0056
0057 class DataSet;
0058 class VariableTransformBase;
0059 class MsgLogger;
0060 class DataSetManager;
0061
0062 class DataSetInfo : public TObject {
0063
0064 public:
0065
0066 enum { kIsArrayVariable = BIT(15) };
0067
0068 DataSetInfo(const TString& name = "Default");
0069 virtual ~DataSetInfo();
0070
0071 virtual const char* GetName() const { return fName.Data(); }
0072
0073
0074 void ClearDataSet() const;
0075 DataSet* GetDataSet() const;
0076
0077
0078
0079
0080 VariableInfo& AddVariable( const TString& expression, const TString& title = "", const TString& unit = "",
0081 Double_t min = 0, Double_t max = 0, char varType='F',
0082 Bool_t normalized = kTRUE, void* external = nullptr );
0083 VariableInfo& AddVariable( const VariableInfo& varInfo );
0084
0085
0086 void AddVariablesArray(const TString &expression, Int_t size, const TString &title = "", const TString &unit = "",
0087 Double_t min = 0, Double_t max = 0, char type = 'F', Bool_t normalized = kTRUE,
0088 void *external = nullptr );
0089
0090 VariableInfo& AddTarget ( const TString& expression, const TString& title, const TString& unit,
0091 Double_t min, Double_t max, Bool_t normalized = kTRUE, void* external = nullptr );
0092 VariableInfo& AddTarget ( const VariableInfo& varInfo );
0093
0094 VariableInfo& AddSpectator ( const TString& expression, const TString& title, const TString& unit,
0095 Double_t min, Double_t max, char type = 'F', Bool_t normalized = kTRUE, void* external = nullptr );
0096 VariableInfo& AddSpectator ( const VariableInfo& varInfo );
0097
0098 ClassInfo* AddClass ( const TString& className );
0099
0100
0101
0102
0103 std::vector<VariableInfo>& GetVariableInfos() { return fVariables; }
0104 const std::vector<VariableInfo>& GetVariableInfos() const { return fVariables; }
0105 VariableInfo& GetVariableInfo( Int_t i ) { return fVariables.at(i); }
0106 const VariableInfo& GetVariableInfo( Int_t i ) const { return fVariables.at(i); }
0107
0108 Int_t GetVarArraySize(const TString &expression) const {
0109 auto element = fVarArrays.find(expression);
0110 return (element != fVarArrays.end()) ? element->second : -1;
0111 }
0112 Bool_t IsVariableFromArray(Int_t i) const { return GetVariableInfo(i).TestBit(DataSetInfo::kIsArrayVariable); }
0113
0114 std::vector<VariableInfo> &GetTargetInfos()
0115 {
0116 return fTargets;
0117 }
0118 const std::vector<VariableInfo> &GetTargetInfos() const { return fTargets; }
0119 VariableInfo &GetTargetInfo(Int_t i) { return fTargets.at(i); }
0120 const VariableInfo &GetTargetInfo(Int_t i) const { return fTargets.at(i); }
0121
0122 std::vector<VariableInfo> &GetSpectatorInfos() { return fSpectators; }
0123 const std::vector<VariableInfo> &GetSpectatorInfos() const { return fSpectators; }
0124 VariableInfo &GetSpectatorInfo(Int_t i) { return fSpectators.at(i); }
0125 const VariableInfo &GetSpectatorInfo(Int_t i) const { return fSpectators.at(i); }
0126
0127 UInt_t GetNVariables() const { return fVariables.size(); }
0128 UInt_t GetNTargets() const { return fTargets.size(); }
0129 UInt_t GetNSpectators(bool all = kTRUE) const;
0130
0131 const TString &GetNormalization() const { return fNormalization; }
0132 void SetNormalization(const TString &norm) { fNormalization = norm; }
0133
0134 void SetTrainingSumSignalWeights(Double_t trainingSumSignalWeights)
0135 {
0136 fTrainingSumSignalWeights = trainingSumSignalWeights;}
0137 void SetTrainingSumBackgrWeights(Double_t trainingSumBackgrWeights){fTrainingSumBackgrWeights = trainingSumBackgrWeights;}
0138 void SetTestingSumSignalWeights (Double_t testingSumSignalWeights ){fTestingSumSignalWeights = testingSumSignalWeights ;}
0139 void SetTestingSumBackgrWeights (Double_t testingSumBackgrWeights ){fTestingSumBackgrWeights = testingSumBackgrWeights ;}
0140
0141 Double_t GetTrainingSumSignalWeights();
0142 Double_t GetTrainingSumBackgrWeights();
0143 Double_t GetTestingSumSignalWeights ();
0144 Double_t GetTestingSumBackgrWeights ();
0145
0146
0147
0148
0149 Int_t GetClassNameMaxLength() const;
0150 Int_t GetVariableNameMaxLength() const;
0151 Int_t GetTargetNameMaxLength() const;
0152 ClassInfo* GetClassInfo( Int_t clNum ) const;
0153 ClassInfo* GetClassInfo( const TString& name ) const;
0154 void PrintClasses() const;
0155 UInt_t GetNClasses() const { return fClasses.size(); }
0156 Bool_t IsSignal( const Event* ev ) const;
0157 std::vector<Float_t>* GetTargetsForMulticlass( const Event* ev );
0158 UInt_t GetSignalClassIndex(){return fSignalClass;}
0159
0160
0161 Int_t FindVarIndex( const TString& ) const;
0162
0163
0164 const TString GetWeightExpression(Int_t i) const { return GetClassInfo(i)->GetWeight(); }
0165 void SetWeightExpression( const TString& exp, const TString& className = "" );
0166
0167
0168 const TCut& GetCut (Int_t i) const { return GetClassInfo(i)->GetCut(); }
0169 const TCut& GetCut ( const TString& className ) const { return GetClassInfo(className)->GetCut(); }
0170 void SetCut ( const TCut& cut, const TString& className );
0171 void AddCut ( const TCut& cut, const TString& className );
0172 Bool_t HasCuts() const;
0173
0174 std::vector<TString> GetListOfVariables() const;
0175
0176
0177 const TMatrixD* CorrelationMatrix ( const TString& className ) const;
0178 void SetCorrelationMatrix ( const TString& className, TMatrixD* matrix );
0179 void PrintCorrelationMatrix( const TString& className );
0180 TH2* CreateCorrelationMatrixHist( const TMatrixD* m,
0181 const TString& hName,
0182 const TString& hTitle ) const;
0183
0184
0185 void SetSplitOptions(const TString& so) { fSplitOptions = so; fNeedsRebuilding = kTRUE; }
0186 const TString& GetSplitOptions() const { return fSplitOptions; }
0187
0188
0189 void SetRootDir(TDirectory* d) { fOwnRootDir = d; }
0190 TDirectory* GetRootDir() const { return fOwnRootDir; }
0191
0192 void SetMsgType( EMsgType t ) const;
0193
0194 DataSetManager* GetDataSetManager(){return fDataSetManager;}
0195 private:
0196
0197 TMVA::DataSetManager* fDataSetManager;
0198 void SetDataSetManager( DataSetManager* dsm ) { fDataSetManager = dsm; }
0199 friend class DataSetManager;
0200
0201 DataSetInfo(const DataSetInfo &) = delete;
0202 DataSetInfo & operator= (const DataSetInfo &) = delete;
0203
0204 void PrintCorrelationMatrix( TTree* theTree );
0205
0206 TString fName;
0207
0208 mutable DataSet* fDataSet;
0209 mutable Bool_t fNeedsRebuilding;
0210
0211
0212 std::vector<VariableInfo> fVariables;
0213 std::vector<VariableInfo> fTargets;
0214 std::vector<VariableInfo> fSpectators;
0215
0216
0217 std::map<TString, int> fVarArrays;
0218
0219
0220 mutable std::vector<ClassInfo*> fClasses;
0221
0222 TString fNormalization;
0223 TString fSplitOptions;
0224
0225 Double_t fTrainingSumSignalWeights;
0226 Double_t fTrainingSumBackgrWeights;
0227 Double_t fTestingSumSignalWeights ;
0228 Double_t fTestingSumBackgrWeights ;
0229
0230
0231
0232 TDirectory* fOwnRootDir;
0233 Bool_t fVerbose;
0234
0235 UInt_t fSignalClass;
0236
0237 std::vector<Float_t>* fTargetsForMulticlass;
0238
0239 mutable MsgLogger* fLogger;
0240 MsgLogger& Log() const { return *fLogger; }
0241
0242 public:
0243
0244 ClassDef(DataSetInfo,1);
0245 };
0246 }
0247
0248 #endif