Warning: file /include/root/TMVA/MethodKNN.h was not indexed,
or was modified since it was last indexed (in which case cross-reference links may be missing, inaccurate, or erroneous).
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026 #ifndef ROOT_TMVA_MethodKNN
0027 #define ROOT_TMVA_MethodKNN
0028
0029
0030
0031
0032
0033
0034
0035
0036
0037 #include <vector>
0038
0039
0040 #include "TMVA/MethodBase.h"
0041 #include "TMVA/ModulekNN.h"
0042
0043
0044 #include "TMVA/LDA.h"
0045
0046 namespace TMVA
0047 {
0048 namespace kNN
0049 {
0050 class ModulekNN;
0051 }
0052
0053 class MethodKNN : public MethodBase
0054 {
0055 public:
0056
0057 MethodKNN(const TString& jobName,
0058 const TString& methodTitle,
0059 DataSetInfo& theData,
0060 const TString& theOption = "KNN");
0061
0062 MethodKNN(DataSetInfo& theData,
0063 const TString& theWeightFile);
0064
0065 virtual ~MethodKNN( void );
0066
0067 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
0068
0069 void Train( void );
0070
0071 Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0072 const std::vector<Float_t>& GetRegressionValues();
0073
0074 using MethodBase::ReadWeightsFromStream;
0075
0076 void WriteWeightsToStream(TFile& rf) const;
0077 void AddWeightsXMLTo( void* parent ) const;
0078 void ReadWeightsFromXML( void* wghtnode );
0079
0080 void ReadWeightsFromStream(std::istream& istr);
0081 void ReadWeightsFromStream(TFile &rf);
0082
0083 const Ranking* CreateRanking();
0084
0085 protected:
0086
0087
0088 void MakeClassSpecific( std::ostream&, const TString& ) const;
0089
0090
0091 void GetHelpMessage() const;
0092
0093 private:
0094
0095
0096 void DeclareOptions();
0097 void ProcessOptions();
0098 void DeclareCompatibilityOptions();
0099
0100
0101 void Init( void );
0102
0103
0104 void MakeKNN( void );
0105
0106
0107 Double_t PolnKernel(Double_t value) const;
0108 Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector<Double_t> &svec) const;
0109
0110 Double_t getKernelRadius(const kNN::List &rlist) const;
0111 const std::vector<Double_t> getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const;
0112
0113 double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn);
0114
0115 private:
0116
0117
0118 Double_t fSumOfWeightsS;
0119 Double_t fSumOfWeightsB;
0120
0121 kNN::ModulekNN *fModule;
0122
0123 Int_t fnkNN;
0124 Int_t fBalanceDepth;
0125
0126 Float_t fScaleFrac;
0127 Float_t fSigmaFact;
0128
0129 TString fKernel;
0130
0131 Bool_t fTrim;
0132 Bool_t fUseKernel;
0133 Bool_t fUseWeight;
0134 Bool_t fUseLDA;
0135
0136 kNN::EventVec fEvent;
0137
0138 LDA fLDA;
0139
0140
0141 Int_t fTreeOptDepth;
0142
0143 ClassDef(MethodKNN,0);
0144 };
0145
0146 }
0147
0148 #endif