Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-30 10:22:53

0001 // @(#)root/tmva $Id$
0002 // Author: Rustem Ospanov
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : MethodKNN                                                             *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Analysis of k-nearest neighbor                                            *
0012  *                                                                                *
0013  * Author:                                                                        *
0014  *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
0015  *                                                                                *
0016  * Copyright (c) 2007:                                                            *
0017  *      CERN, Switzerland                                                         *
0018  *      MPI-K Heidelberg, Germany                                                 *
0019  *      U. of Texas at Austin, USA                                                *
0020  *                                                                                *
0021  * Redistribution and use in source and binary forms, with or without             *
0022  * modification, are permitted according to the terms listed in LICENSE           *
0023  * (see tmva/doc/LICENSE)                                                         *
0024  **********************************************************************************/
0025 
0026 #ifndef ROOT_TMVA_MethodKNN
0027 #define ROOT_TMVA_MethodKNN
0028 
0029 //////////////////////////////////////////////////////////////////////////
0030 //                                                                      //
0031 // MethodKNN                                                            //
0032 //                                                                      //
0033 // Analysis of k-nearest neighbor                                       //
0034 //                                                                      //
0035 //////////////////////////////////////////////////////////////////////////
0036 
0037 #include <vector>
0038 
0039 // Local
0040 #include "TMVA/MethodBase.h"
0041 #include "TMVA/ModulekNN.h"
0042 
0043 // SVD and linear discriminant code
0044 #include "TMVA/LDA.h"
0045 
0046 namespace TMVA
0047 {
0048    namespace kNN
0049    {
           // Forward declaration of kNN::ModulekNN, the type MethodKNN refers to
           // through its fModule pointer member.
           // NOTE(review): "TMVA/ModulekNN.h" is already included above, so this
           // declaration is presumably redundant - confirm before removing it.
0050       class ModulekNN;
0051    }
0052 
0053    class MethodKNN : public MethodBase
0054    {
           // TMVA method for the analysis of k-nearest neighbours, derived from
           // MethodBase. This header only declares the interface; the data members
           // hold the configuration (number of neighbours, kernel choice, LDA
           // switch, ...) plus a pointer to the kNN::ModulekNN object that the
           // member comment below describes as the "module where all work is done".
           //
           // NOTE(review): none of the member functions are marked `override`.
           // Which of them override MethodBase virtuals cannot be verified from
           // this header alone - check against MethodBase before adding it.
0055    public:
0056 
           // Constructor taking the job name, the method title, the data set
           // description and an option string (defaults to "KNN").
0057       MethodKNN(const TString& jobName,
0058                 const TString& methodTitle,
0059                 DataSetInfo& theData,
0060                 const TString& theOption = "KNN");
0061 
           // Constructor taking the data set description and a weight file name;
           // presumably used to restore a previously trained method - confirm in
           // the implementation file.
0062       MethodKNN(DataSetInfo& theData,
0063                 const TString& theWeightFile);
0064 
0065       virtual ~MethodKNN( void );
0066 
           // Returns a Bool_t for the given analysis type, number of classes and
           // number of targets; presumably kTRUE when that combination is
           // supported - confirm in the implementation file.
0067       virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
0068 
           // Train the method (implementation not visible in this header).
0069       void Train( void );
0070 
           // Returns the MVA value as a Double_t. err / errUpper are optional
           // pointers (default nullptr); presumably filled with an error estimate
           // when non-null - TODO confirm.
0071       Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
           // Returns a const reference to a vector of Float_t regression values.
           // NOTE(review): the referenced storage must outlive the call; its owner
           // is not visible from here - confirm.
0072       const std::vector<Float_t>& GetRegressionValues();
0073 
           // Keep the MethodBase overloads of ReadWeightsFromStream visible: the
           // overloads declared below would otherwise hide the base-class ones.
0074       using MethodBase::ReadWeightsFromStream;
0075 
           // Weight persistence, in three flavours: ROOT file (TFile&), XML node
           // (passed as opaque void* handles) and std::istream.
0076       void WriteWeightsToStream(TFile& rf) const;
0077       void AddWeightsXMLTo( void* parent ) const;
0078       void ReadWeightsFromXML( void* wghtnode );
0079 
0080       void ReadWeightsFromStream(std::istream& istr);
0081       void ReadWeightsFromStream(TFile &rf);
0082 
           // Returns a pointer to a const Ranking. Ownership of the returned
           // object is not visible from this header - TODO confirm.
0083       const Ranking* CreateRanking();
0084 
0085    protected:
0086 
0087       // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
           // Takes an output stream and a TString (parameters are unnamed here).
0088       void MakeClassSpecific( std::ostream&, const TString& ) const;
0089 
0090       // get help message text
0091       void GetHelpMessage() const;
0092 
0093    private:
0094 
0095       // the option handling methods
0096       void DeclareOptions();
0097       void ProcessOptions();
0098       void DeclareCompatibilityOptions();
0099 
0100       // default initialisation called by all constructors
0101       void Init( void );
0102 
0103       // create kd-tree (binary tree) structure
0104       void MakeKNN( void );
0105 
0106       // polynomial and Gaussian kernel weight function
           // PolnKernel takes a single Double_t value; GausKernel takes two kNN
           // events and a vector of Double_t (svec) - presumably per-variable
           // widths, confirm in the implementation file.
0107       Double_t PolnKernel(Double_t value) const;
0108       Double_t GausKernel(const kNN::Event &event_knn, const kNN::Event &event, const std::vector<Double_t> &svec) const;
0109 
           // Helpers operating on a kNN::List (rlist); judging by their names they
           // derive a kernel radius and per-variable RMS values - TODO confirm.
           // NOTE(review): getRMS returns a const std::vector by value, which
           // prevents callers from moving out of the result.
0110       Double_t getKernelRadius(const kNN::List &rlist) const;
0111       const std::vector<Double_t> getRMS(const kNN::List &rlist, const kNN::Event &event_knn) const;
0112 
           // NOTE(review): declared with plain `double` and without `const`,
           // unlike the Double_t / const helpers above.
0113       double getLDAValue(const kNN::List &rlist, const kNN::Event &event_knn);
0114 
           // Second private section, holding the data members (the access
           // specifier repeats the one above).
0115    private:
0116 
0117       // number of events (sumOfWeights)
0118       Double_t fSumOfWeightsS;        ///< sum-of-weights for signal training events
0119       Double_t fSumOfWeightsB;        ///< sum-of-weights for background training events
0120 
0121       kNN::ModulekNN *fModule;        ///<! module where all work is done
0122 
0123       Int_t fnkNN;            ///< number of k-nearest neighbors
0124       Int_t fBalanceDepth;    ///< number of binary tree levels used for balancing tree
0125 
0126       Float_t fScaleFrac;     ///< fraction of events used to compute variable width
0127       Float_t fSigmaFact;     ///< scale factor for Gaussian sigma in Gaus. kernel
0128 
0129       TString fKernel;        ///< ="Gaus","Poln" - kernel type for smoothing
0130 
0131       Bool_t fTrim;           ///< set equal number of signal and background events
0132       Bool_t fUseKernel;      ///< use polynomial kernel weight function
0133       Bool_t fUseWeight;      ///< use weights to count kNN
0134       Bool_t fUseLDA;         ///< use local linear discriminant analysis to compute MVA
0135 
0136       kNN::EventVec fEvent;   ///<! (untouched) events used for learning
0137 
0138       LDA fLDA;               ///<! Experimental feature for local knn analysis
0139 
0140       // for backward compatibility
0141       Int_t fTreeOptDepth;    ///< number of binary tree levels used for optimization
0142 
           // ROOT ClassDef macro, declared here with class version 0.
0143       ClassDef(MethodKNN,0); // k Nearest Neighbour classifier
0144    };
0145 
0146 } // namespace TMVA
0147 
0148 #endif // MethodKNN