File indexing completed on 2025-01-30 10:22:53
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifndef ROOT_TMVA_MethodKNN
0027 #define ROOT_TMVA_MethodKNN
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037 #include <vector>
0038
0039
0040 #include "TMVA/MethodBase.h"
0041 #include "TMVA/ModulekNN.h"
0042
0043
0044 #include "TMVA/LDA.h"
0045
0046 namespace TMVA
0047 {
namespace kNN {
   // Forward declaration of the type behind MethodKNN::fModule (a pointer
   // member), so the pointer declaration only needs the name.
   // NOTE(review): likely redundant given the "TMVA/ModulekNN.h" include;
   // kept because it is harmless.
   class ModulekNN;
}  // namespace kNN
0052
0053 class MethodKNN : public MethodBase
0054 {
0055 public:
0056
0057 MethodKNN(const TString& jobName,
0058 const TString& methodTitle,
0059 DataSetInfo& theData,
0060 const TString& theOption = "KNN");
0061
0062 MethodKNN(DataSetInfo& theData,
0063 const TString& theWeightFile);
0064
0065 virtual ~MethodKNN( void );
0066
0067 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
0068
0069 void Train( void );
0070
0071 Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0072 const std::vector<Float_t>& GetRegressionValues();
0073
0074 using MethodBase::ReadWeightsFromStream;
0075
0076 void WriteWeightsToStream(TFile& rf) const;
0077 void AddWeightsXMLTo( void* parent ) const;
0078 void ReadWeightsFromXML( void* wghtnode );
0079
0080 void ReadWeightsFromStream(std::istream& istr);
0081 void ReadWeightsFromStream(TFile &rf);
0082
0083 const Ranking* CreateRanking();
0084
0085 protected:
0086
0087
0088 void MakeClassSpecific( std::ostream&, const TString& ) const;
0089
0090
0091 void GetHelpMessage() const;
0092
0093 private:
0094
0095
0096 void DeclareOptions();
0097 void ProcessOptions();
0098 void DeclareCompatibilityOptions();
0099
0100
0101 void Init( void );
0102
0103
0104 void MakeKNN( void );
0105
0106
0107 Double_t PolnKernel(Double_t value) const;
0108 Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector<Double_t> &svec) const;
0109
0110 Double_t getKernelRadius(const kNN::List &rlist) const;
0111 const std::vector<Double_t> getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const;
0112
0113 double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn);
0114
0115 private:
0116
0117
0118 Double_t fSumOfWeightsS;
0119 Double_t fSumOfWeightsB;
0120
0121 kNN::ModulekNN *fModule;
0122
0123 Int_t fnkNN;
0124 Int_t fBalanceDepth;
0125
0126 Float_t fScaleFrac;
0127 Float_t fSigmaFact;
0128
0129 TString fKernel;
0130
0131 Bool_t fTrim;
0132 Bool_t fUseKernel;
0133 Bool_t fUseWeight;
0134 Bool_t fUseLDA;
0135
0136 kNN::EventVec fEvent;
0137
0138 LDA fLDA;
0139
0140
0141 Int_t fTreeOptDepth;
0142
0143 ClassDef(MethodKNN,0);
0144 };
0145
0146 }
0147
0148 #endif