File indexing completed on 2025-01-18 10:11:00
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 #ifndef ROOT_TMVA_MethodCrossValidation
0013 #define ROOT_TMVA_MethodCrossValidation
0014
0015
0016
0017
0018
0019
0020
0021 #include "TMVA/CvSplit.h"
0022 #include "TMVA/DataSetInfo.h"
0023 #include "TMVA/MethodBase.h"
0024
0025 #include "TString.h"
0026
0027 #include <iostream>
0028 #include <memory>
0029 #include <vector>
0030 #include <map>
0031
0032 namespace TMVA {
0033
0034 class CrossValidation;
0035 class Ranking;
0036
0037
0038 class MethodCrossValidation : public MethodBase {
0039
0040 friend CrossValidation;
0041
0042 public:
0043
0044 MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData,
0045 const TString &theOption = "");
0046
0047
0048 MethodCrossValidation(DataSetInfo &theData, const TString &theWeightFile);
0049
0050 virtual ~MethodCrossValidation(void);
0051
0052
0053
0054
0055
0056
0057 void Train(void);
0058
0059
0060 void Reset(void);
0061
0062 using MethodBase::ReadWeightsFromStream;
0063
0064
0065 void AddWeightsXMLTo(void *parent) const;
0066
0067
0068 void ReadWeightsFromStream(std::istream &istr);
0069 void ReadWeightsFromXML(void *parent);
0070
0071
0072 void WriteMonitoringHistosToFile(void) const;
0073
0074
0075 Double_t GetMvaValue(Double_t *err = nullptr, Double_t *errUpper = nullptr);
0076 const std::vector<Float_t> &GetMulticlassValues();
0077 const std::vector<Float_t> &GetRegressionValues();
0078
0079
0080 void DeclareOptions();
0081 void ProcessOptions();
0082
0083
0084 void MakeClassSpecific(std::ostream &, const TString &) const;
0085 void MakeClassSpecificHeader(std::ostream &, const TString &) const;
0086
0087 void GetHelpMessage() const;
0088
0089 const Ranking *CreateRanking();
0090 Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
0091
0092 protected:
0093 void Init(void);
0094 void DeclareCompatibilityOptions();
0095
0096 private:
0097 TString GetWeightFileNameForFold(UInt_t iFold) const;
0098 MethodBase *InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const;
0099
0100 private:
0101 TString fEncapsulatedMethodName;
0102 TString fEncapsulatedMethodTypeName;
0103 UInt_t fNumFolds;
0104 TString fOutputEnsembling;
0105
0106 TString fSplitExprString;
0107 std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
0108
0109 std::vector<Float_t> fMulticlassValues;
0110 std::vector<Float_t> fRegressionValues;
0111
0112 std::vector<MethodBase *> fEncapsulatedMethods;
0113
0114
0115
0116
0117 std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
0118
0119
0120 ClassDef(MethodCrossValidation, 0);
0121 };
0122
0123 }
0124
0125 #endif