File indexing completed on 2025-01-18 10:11:00
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030
0031 #ifndef ROOT_TMVA_MethodCategory
0032 #define ROOT_TMVA_MethodCategory
0033
0034
0035
0036
0037
0038
0039
0040
0041
0042 #include <iosfwd>
0043 #include <vector>
0044
0045 #include "TMVA/MethodBase.h"
0046
0047 #include "TMVA/MethodCompositeBase.h"
0048
0049 namespace TMVA {
0050 // Forward declarations: each type below is only named further down in this header, either in a friend declaration of MethodCategory or as the pointee type of one of its pointer members.
0051 class Factory;
0052 class Reader;
0053 class MethodBoost;
0054 class DataSetManager;
0055 namespace Experimental {
0056 class Classification;
0057 }
0058 class MethodCategory : public MethodCompositeBase { // TMVA method derived from MethodCompositeBase; apart from the empty inline MakeClass, this header only declares members
0059 friend class Experimental::Classification;
0060 // Stores a list of sub-methods (fMethods) next to a list of cuts (fCategoryCuts). NOTE(review): presumably each cut selects the events handled by one sub-method — confirm in the implementation.
0061 public :
0062 // Constructor taking a job name, a method title, the dataset description and an option string (defaults to "").
0063 // NOTE(review): presumably the constructor used when the method is booked for training — confirm.
0064 MethodCategory( const TString& jobName,
0065 const TString& methodTitle,
0066 DataSetInfo& theData,
0067 const TString& theOption = "" );
0068 // Constructor taking the dataset description and a weight-file name. NOTE(review): presumably used to restore a trained method — confirm.
0069 MethodCategory( DataSetInfo& dsi,
0070 const TString& theWeightFile );
0071 // Virtual destructor; the definition is not in this header.
0072 virtual ~MethodCategory( void );
0073 // Returns a Bool_t for the given analysis type and class count (presumably "is this combination supported?"); the third UInt_t parameter is unnamed.
0074 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t );
0075 // Training entry point; takes no arguments and returns nothing.
0076 void Train( void );
0077 // Returns a pointer to a const Ranking.
0078 // NOTE(review): ownership of the returned object is not visible from this header — confirm before deleting or caching it.
0079 const Ranking* CreateRanking();
0080 // Adds a sub-method described by a cut, a variable-list string, a method type, a title and an option string; returns it as IMethod*.
0081 // NOTE(review): presumably appends to fMethods / fCategoryCuts — confirm in the implementation.
0082 TMVA::IMethod* AddMethod(const TCut&,
0083 const TString& theVariables,
0084 Types::EMVA theMethod,
0085 const TString& theTitle,
0086 const TString& theOptions);
0087 // XML weight writing / reading; both take the XML node as an untyped void* handle.
0088 void AddWeightsXMLTo( void* parent ) const;
0089 void ReadWeightsFromXML( void* wghtnode );
0090 // Returns the MVA response as Double_t; err and errUpper default to nullptr (presumably optional outputs for error estimates — confirm).
0091 Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0092 // Regression response, returned as a const reference to a vector of Float_t.
0093 // NOTE(review): the storage behind the returned reference is not visible here — confirm its lifetime before holding on to it.
0094 virtual const std::vector<Float_t>& GetRegressionValues();
0095 // Multiclass response, returned as a const reference to a vector of Float_t.
0096 // NOTE(review): same lifetime caveat as for GetRegressionValues — confirm.
0097 virtual const std::vector<Float_t> &GetMulticlassValues();
0098 // Empty inline body: this class's MakeClass does nothing, whatever string is passed.
0099 virtual void MakeClass( const TString& = TString("") ) const {};
0100
0101 protected :
0102 // Returns a vector of Double_t responses for an event range; defaults are firstEvt = 0, lastEvt = -1, logProgress = false.
0103 // NOTE(review): lastEvt = -1 presumably means "up to the last event" — confirm.
0104 virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt = 0, Long64_t lastEvt = -1, Bool_t logProgress = false);
0105
0106 private:
0107 // Initialisation routine; the definition is not in this header.
0108 void Init();
0109 // Option handling, split into a declaration step and a processing step.
0110 // NOTE(review): the set of options is defined in the implementation, not here.
0111 void DeclareOptions();
0112 void ProcessOptions();
0113 // Returns a Bool_t for an event and a sub-method index.
0114 // NOTE(review): presumably evaluates the cut belonging to methodIdx on ev (possibly via fCatFormulas) — confirm.
0115 Bool_t PassesCut( const Event* ev, UInt_t methodIdx );
0116
0117 protected:
0118 // Per-category bookkeeping. NOTE(review): these vectors look like parallel arrays indexed by sub-method — confirm.
0119 // fMethods stores raw IMethod pointers; who owns the pointees is not visible in this header.
0120 std::vector<IMethod*> fMethods;
0121 std::vector<TCut> fCategoryCuts;
0122 std::vector<UInt_t> fCategorySpecIdx;
0123 std::vector<TString> fVars;
0124 std::vector <std::vector <UInt_t> > fVarMaps;
0125 // Const member returning void.
0126 // NOTE(review): presumably prints the help text rather than returning it — confirm.
0127 void GetHelpMessage() const;
0128 // Returns a DataSetInfo by reference from a cut and two strings (parameter names are omitted in the declaration). NOTE(review): presumably creates a per-category dataset description — confirm.
0129 TMVA::DataSetInfo& CreateCategoryDSI(const TCut&, const TString&, const TString&);
0130
0131 private:
0132 // Takes the dataset description by const reference. NOTE(review): presumably sets up fCatTree / fCatFormulas — confirm.
0133 void InitCircularTree(const DataSetInfo& dsi);
0134 // Raw pointers to a TTree and to TTreeFormula objects; ownership is not visible in this header.
0135 TTree * fCatTree;
0136 std::vector<TTreeFormula*> fCatFormulas;
0137 // Raw DataSetManager pointer. Factory, Reader and MethodBoost are friends (below) and can therefore access it. NOTE(review): presumably one of them sets it — confirm.
0138 DataSetManager* fDataSetManager;
0139 friend class Factory;
0140 friend class Reader;
0141 friend class MethodBoost;
0142 // ROOT dictionary macro with class version 0 (per ROOT's documentation, version 0 means the class is not set up for object I/O).
0143 ClassDef(MethodCategory,0);
0144 };
0145 }
0146
0147 #endif