Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:58

0001 // @(#)root/tmva $Id$
0002 // Author: Kim Albertsson
0003 
0004 /*************************************************************************
0005  * Copyright (C) 2018, Rene Brun and Fons Rademakers.                    *
0006  * All rights reserved.                                                  *
0007  *                                                                       *
0008  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0009  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0010  *************************************************************************/
0011 
0012 #ifndef ROOT_TMVA_CvSplit
0013 #define ROOT_TMVA_CvSplit
0014 
0015 #include "TMVA/Configurable.h"
0016 #include "TMVA/Types.h"
0017 
0018 #include <Rtypes.h>
0019 #include <TFormula.h>
0020 
0021 #include <memory>
0022 #include <vector>
0023 #include <map>
0024 
0025 class TString;
0026 
0027 namespace TMVA {
0028 
0029 class CrossValidation;
0030 class DataSetInfo;
0031 class Event;
0032 
0033 /* =============================================================================
0034       TMVA::CvSplit
0035 ============================================================================= */
0036 
0037 class CvSplit : public Configurable {
     // Abstract base for cross-validation splitters. It stores the fold count and
     // the per-split event containers, and leaves MakeKFoldDataSet() to subclasses.
0038 public:
        /// Construct a splitter for `numFolds` folds. Defined out of line.
0039    CvSplit(UInt_t numFolds);
0040    virtual ~CvSplit() {}
0041 
        /// Pure virtual: the concrete splitter decides how `dsi` is divided into folds.
0042    virtual void MakeKFoldDataSet(DataSetInfo &dsi) = 0;
        /// Declared here, defined out of line.
        /// NOTE(review): presumably selects fold `foldNumber` of tree type `tt` in `dsi` -- confirm in the implementation.
0043    virtual void PrepareFoldDataSet(DataSetInfo &dsi, UInt_t foldNumber, Types::ETreeType tt);
        /// Declared here, defined out of line; `tt` defaults to Types::kTraining.
        /// NOTE(review): presumably merges the folds back into `dsi` -- confirm in the implementation.
0044    virtual void RecombineKFoldDataSet(DataSetInfo &dsi, Types::ETreeType tt = Types::kTraining);
0045 
        // Both accessors only read a member, so they are const-qualified and can be
        // called through a const reference or pointer.
        /// Returns the stored number of folds.
0046    UInt_t GetNumFolds() const { return fNumFolds; }
        /// Returns the current value of the fMakeFoldDataSet flag.
0047    Bool_t NeedsRebuild() const { return fMakeFoldDataSet; }
0048 
0049 protected:
0050    UInt_t fNumFolds;        ///< Number of folds, as returned by GetNumFolds().
0051    Bool_t fMakeFoldDataSet; ///< Flag reported by NeedsRebuild().
0052 
        // NOTE(review): the outer vector is presumably indexed by fold -- confirm in the implementation.
0053    std::vector<std::vector<TMVA::Event *>> fTrainEvents;
0054    std::vector<std::vector<TMVA::Event *>> fTestEvents;
0055 
0056 protected:
        // ROOT dictionary macro; the second argument is the class version (0 here).
0057    ClassDef(CvSplit, 0);
0058 };
0059 
0060 /* =============================================================================
0061       TMVA::CvSplitKFoldsExpr
0062 ============================================================================= */
0063 
0064 class CvSplitKFoldsExpr {
     // Helper that maps an event to a fold number using a split expression that is
     // held both as text (fSplitExpr) and as a TFormula (fSplitFormula).
     // All non-trivial members are defined out of line.
0065 public:
        /// Construct from the dataset description and the split expression text.
0066    CvSplitKFoldsExpr(DataSetInfo &dsi, TString expr);
0067    ~CvSplitKFoldsExpr() {}
0068 
        /// Returns an unsigned fold number for event `ev`, given `numFolds`.
        /// NOTE(review): the result is presumably in [0, numFolds) -- confirm in the implementation.
0069    UInt_t Eval(UInt_t numFolds, const Event *ev);
0070 
        /// Static check of a split expression; callable without an instance.
        /// NOTE(review): the validity criteria are decided in the implementation -- confirm there.
0071    static Bool_t Validate(TString expr);
0072 
0073 private:
        /// Looks up a spectator by `name` in `dsi` and returns its index.
        /// NOTE(review): behaviour for an unknown name is not visible here -- confirm.
0074    UInt_t GetSpectatorIndexForName(DataSetInfo &dsi, TString name);
0075 
0076 private:
        // Reference member: the referenced DataSetInfo must outlive this object.
0077    DataSetInfo &fDsi;
0078 
0079    std::vector<std::pair<Int_t, Int_t>>
0080       fFormulaParIdxToDsiSpecIdx; //! Maps parameter indices in splitExpr to their spectator index in the datasetinfo.
0081    Int_t fIdxFormulaParNumFolds;  //! Keeps track of the index of reserved par "NumFolds" in splitExpr.
0082    TString fSplitExpr;     //! Expression used to split data into folds. Should output values between 0 and numFolds.
0083    TFormula fSplitFormula; //! TFormula for splitExpr.
0084 
        // NOTE(review): presumably the parameter values handed to fSplitFormula on evaluation -- confirm.
0085    std::vector<Double_t> fParValues;
0086 };
0087 
0088 /* =============================================================================
0089       TMVA::CvSplitKFolds
0090 ============================================================================= */
0091 
0092 class CvSplitKFolds : public CvSplit {
     // CvSplit subclass that overrides MakeKFoldDataSet(). It carries an optional
     // split expression (fSplitExpr), a stratification flag (fStratified) and a seed (fSeed).
0093 
     // CrossValidation is granted access to the private members below.
0094    friend CrossValidation;
0095 
0096 public:
        /// Construct a k-fold splitter. Defaults: empty split expression,
        /// stratified = kTRUE, seed = 100. Defined out of line.
0097    CvSplitKFolds(UInt_t numFolds, TString splitExpr = "", Bool_t stratified = kTRUE, UInt_t seed = 100);
0098    ~CvSplitKFolds() override {}
0099 
        /// Overrides the base-class hook that divides `dsi` into folds. Defined out of line.
0100    void MakeKFoldDataSet(DataSetInfo &dsi) override;
0101 
0102 private:
        /// Splits `oldSet` and returns a vector of event vectors. Defined out of line.
        /// NOTE(review): presumably one inner vector per fold, with `numClasses` used for stratification -- confirm.
0103    std::vector<std::vector<Event *>> SplitSets(std::vector<TMVA::Event *> &oldSet, UInt_t numFolds, UInt_t numClasses);
        /// Returns a vector of fold numbers; `seed` defaults to 100. Defined out of line.
        /// NOTE(review): presumably one entry per event index in [0, nEntries) -- confirm.
0104    std::vector<UInt_t> GetEventIndexToFoldMapping(UInt_t nEntries, UInt_t numFolds, UInt_t seed = 100);
0105 
0106 private:
        // NOTE(review): presumably the random seed used for fold assignment -- confirm.
0107    UInt_t fSeed;
0108    TString fSplitExprString; ///<! Expression used to split data into folds. Should output values between 0 and numFolds.
        // Owned expression evaluator (unique_ptr).
        // NOTE(review): presumably null when no split expression is given -- confirm.
0109    std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
0110    Bool_t fStratified; ///< If true, use stratified split. (Balance class presence in each fold).
0111 
0112    // Used for CrossValidation with random splits (not using the
0113    // CvSplitKFoldsExpr functionality) to communicate Event to fold mapping.
0114    std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
0115 
0116 private:
        // ROOT dictionary macro (override-aware variant); the class version is 0 here.
0117    ClassDefOverride(CvSplitKFolds, 0);
0118 };
0119 
0120 } // end namespace TMVA
0121 
0122 #endif