Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:00

0001 // @(#)root/tmva $Id$
0002 // Author: Kim Albertsson
0003 
0004 /*************************************************************************
0005  * Copyright (C) 2018, Rene Brun and Fons Rademakers.                    *
0006  * All rights reserved.                                                  *
0007  *                                                                       *
0008  * For the licensing terms see $ROOTSYS/LICENSE.                         *
0009  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
0010  *************************************************************************/
0011 
0012 #ifndef ROOT_TMVA_MethodCrossValidation
0013 #define ROOT_TMVA_MethodCrossValidation
0014 
0015 //////////////////////////////////////////////////////////////////////////
0016 //                                                                      //
0017 // MethodCrossValidation                                                //
0018 //                                                                      //
0019 //////////////////////////////////////////////////////////////////////////
0020 
0021 #include "TMVA/CvSplit.h"
0022 #include "TMVA/DataSetInfo.h"
0023 #include "TMVA/MethodBase.h"
0024 
0025 #include "TString.h"
0026 
0027 #include <iostream>
0028 #include <memory>
0029 #include <vector>
0030 #include <map>
0031 
0032 namespace TMVA {
0033 
0034 class CrossValidation;
0035 class Ranking;
0036 
0037 // Looks for serialised methods of the form methodTitle + "_fold" + iFold;
0038 class MethodCrossValidation : public MethodBase {
0039 
0040    friend CrossValidation;
0041 
0042 public:
0043    // constructor for training and reading
0044    MethodCrossValidation(const TString &jobName, const TString &methodTitle, DataSetInfo &theData,
0045                          const TString &theOption = "");
0046 
0047    // constructor for calculating BDT-MVA using previously generated decision trees
0048    MethodCrossValidation(DataSetInfo &theData, const TString &theWeightFile);
0049 
0050    virtual ~MethodCrossValidation(void);
0051 
0052    // optimize tuning parameters
0053    // virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString
0054    // fitType="FitGA"); virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
0055 
0056    // training method
0057    void Train(void);
0058 
0059    // revoke training
0060    void Reset(void);
0061 
0062    using MethodBase::ReadWeightsFromStream;
0063 
0064    // write weights to file
0065    void AddWeightsXMLTo(void *parent) const;
0066 
0067    // read weights from file
0068    void ReadWeightsFromStream(std::istream &istr);
0069    void ReadWeightsFromXML(void *parent);
0070 
0071    // write method specific histos to target file
0072    void WriteMonitoringHistosToFile(void) const;
0073 
0074    // calculate the MVA value
0075    Double_t GetMvaValue(Double_t *err = nullptr, Double_t *errUpper = nullptr);
0076    const std::vector<Float_t> &GetMulticlassValues();
0077    const std::vector<Float_t> &GetRegressionValues();
0078 
0079    // the option handling methods
0080    void DeclareOptions();
0081    void ProcessOptions();
0082 
0083    // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
0084    void MakeClassSpecific(std::ostream &, const TString &) const;
0085    void MakeClassSpecificHeader(std::ostream &, const TString &) const;
0086 
0087    void GetHelpMessage() const;
0088 
0089    const Ranking *CreateRanking();
0090    Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
0091 
0092 protected:
0093    void Init(void);
0094    void DeclareCompatibilityOptions();
0095 
0096 private:
0097    TString GetWeightFileNameForFold(UInt_t iFold) const;
0098    MethodBase *InstantiateMethodFromXML(TString methodTypeName, TString weightfile) const;
0099 
0100 private:
0101    TString fEncapsulatedMethodName;
0102    TString fEncapsulatedMethodTypeName;
0103    UInt_t fNumFolds;
0104    TString fOutputEnsembling;
0105 
0106    TString fSplitExprString;
0107    std::unique_ptr<CvSplitKFoldsExpr> fSplitExpr;
0108 
0109    std::vector<Float_t> fMulticlassValues;
0110    std::vector<Float_t> fRegressionValues;
0111 
0112    std::vector<MethodBase *> fEncapsulatedMethods;
0113 
0114    // Used for CrossValidation with random splits (not using the
0115    // CVSplitCrossValisationExpr functionality) to communicate Event to fold
0116    // mapping.
0117    std::map<const TMVA::Event *, UInt_t> fEventToFoldMapping;
0118 
0119    // for backward compatibility
0120    ClassDef(MethodCrossValidation, 0);
0121 };
0122 
0123 } // namespace TMVA
0124 
0125 #endif