File indexing completed on 2025-01-18 10:10:58
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012 #ifndef ROOT_TMVA_CvSplit
0013 #define ROOT_TMVA_CvSplit
0014
0015 #include "TMVA/Configurable.h"
0016 #include "TMVA/Types.h"
0017
0018 #include <Rtypes.h>
0019 #include <TFormula.h>
0020
0021 #include <memory>
0022 #include <vector>
0023 #include <map>
0024
0025 class TString;
0026
0027 namespace TMVA {
0028
0029 class CrossValidation;
0030 class DataSetInfo;
0031 class Event;
0032
0033
0034
0035
0036
0037 class CvSplit : public Configurable {
0038 public:
0039 CvSplit(UInt_t numFolds);
0040 virtual ~CvSplit() {}
0041
0042 virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
0043 virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
0044 virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt = Types::kTraining);
0045
0046 UInt_t GetNumFolds() { return fNumFolds; }
0047 Bool_t NeedsRebuild() { return fMakeFoldDataSet; }
0048
0049 protected:
0050 UInt_t fNumFolds;
0051 Bool_t fMakeFoldDataSet;
0052
0053 std::vector<std::vector<TMVA::Event *>> fTrainEvents;
0054 std::vector<std::vector<TMVA::Event *>> fTestEvents;
0055
0056 protected:
0057 ClassDef(CvSplit, 0);
0058 };
0059
0060
0061
0062
0063
0064 class CvSplitKFoldsExpr {
0065 public:
0066 CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr);
0067 ~CvSplitKFoldsExpr() {}
0068
0069 UInt_t Eval(UInt_t numFolds, const Event *ev);
0070
0071 static Bool_t Validate(TString expr);
0072
0073 private:
0074 UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name);
0075
0076 private:
0077 DataSetInfo &fDsi;
0078
0079 std::vector<std::pair<Int_t, Int_t>>
0080 fFormulaParIdxToDsiSpecIdx;
0081 Int_t fIdxFormulaParNumFolds;
0082 TString fSplitExpr;
0083 TFormula fSplitFormula;
0084
0085 std::vector<Double_t> fParValues;
0086 };
0087
0088
0089
0090
0091
0092 class CvSplitKFolds : public CvSplit {
0093
0094 friend CrossValidation;
0095
0096 public:
0097 CvSplitKFolds(UInt_t numFolds, TString splitExpr = "", Bool_t stratified = kTRUE, UInt_t seed = 100);
0098 ~CvSplitKFolds() override {}
0099
0100 void MakeKFoldDataSet(DataSetInfo &dsi) override;
0101
0102 private:
0103 std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
0104 std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
0105
0106 private:
0107 UInt_t fSeed;
0108 TString fSplitExprString;
0109 std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
0110 Bool_t fStratified;
0111
0112
0113
0114 std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
0115
0116 private:
0117 ClassDefOverride(CvSplitKFolds, 0);
0118 };
0119
0120 }
0121
0122 #endif