Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:01

0001 // @(#)root/tmva $Id$
0002 // Author: Marcin Wolter, Andrzej Zemla
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : MethodSVM                                                             *
0008  *                                                                                *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Support Vector Machine                                                    *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Marcin Wolter  <Marcin.Wolter@cern.ch> - IFJ PAN, Krakow, Poland          *
0015  *      Andrzej Zemla  <azemla@cern.ch>         - IFJ PAN, Krakow, Poland         *
0016  *      (IFJ PAN: Henryk Niewodniczanski Inst. Nucl. Physics, Krakow, Poland)     *
0017  *                                                                                *
0018  * Introduction of regression by:                                                 *
0019  *      Krzysztof Danielowski <danielow@cern.ch> - IFJ PAN & AGH, Krakow, Poland  *
0020  *      Kamil Kraszewski      <kalq@cern.ch>     - IFJ PAN & UJ, Krakow, Poland   *
0021  *      Maciej Kruk           <mkruk@cern.ch>    - IFJ PAN & AGH, Krakow, Poland  *
0022  *                                                                                *
0023  * Copyright (c) 2005:                                                            *
0024  *      CERN, Switzerland                                                         *
0025  *      MPI-K Heidelberg, Germany                                                 *
0026  *      PAN, Krakow, Poland                                                       *
0027  *                                                                                *
0028  * Redistribution and use in source and binary forms, with or without             *
0029  * modification, are permitted according to the terms listed in LICENSE           *
0030  * (see tmva/doc/LICENSE)                                                         *
0031  **********************************************************************************/
0032 
0033 #ifndef ROOT_TMVA_MethodSVM
0034 #define ROOT_TMVA_MethodSVM
0035 
0036 //////////////////////////////////////////////////////////////////////////
0037 //                                                                      //
0038 // MethodSVM                                                            //
0039 //                                                                      //
0040 // SMO Platt's SVM classifier with Keerthi & Shevade improvements       //
0041 //                                                                      //
0042 //////////////////////////////////////////////////////////////////////////
0043 
0044 #include "TMVA/MethodBase.h"
0045 #include "TMatrixDfwd.h"
0046 #include <string>
0047 #include <vector>
0048 #include <map>
0049 
0050 #ifndef ROOT_TMVA_TVectorD
0051 #include "TVectorD.h"
0052 #include "TMVA/SVKernelFunction.h"
0053 #endif
0054 
0055 namespace TMVA
0056 {
0057    class SVWorkingSet;
0058    class SVEvent;
0059    class SVKernelFunction;
0060 
0061    class MethodSVM : public MethodBase {
0062       // Support Vector Machine method for TMVA, derived from MethodBase. Declarations only; bodies are defined out of line unless shown inline.
0063    public:
0064       // constructor taking a job name, a method title, the data set info and an option string (empty by default)
0065       MethodSVM( const TString& jobName, const TString& methodTitle, DataSetInfo& theData,
0066                  const TString& theOption = "" );
0067       // constructor taking the data set info and a weight-file string
0068       MethodSVM( DataSetInfo& theData, const TString& theWeightFile);
0069 
0070       virtual ~MethodSVM( void );
0071       // query on an analysis type together with class and target counts; NOTE(review): exact support rules live in the implementation
0072       virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
0073 
0074       // optimise tuning parameters; defaults are fomType="ROCIntegral" and fitType="Minuit"
0075       virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="Minuit");
0076       virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters); // receives a name -> value map (passed by value)
0077       std::vector<TMVA::SVKernelFunction::EKernelType> MakeKernelList(std::string multiKernels, TString kernel); // returns a vector of kernel type enums
0078       std::map< TString,std::vector<Double_t> > GetTuningOptions(); // returns a name -> vector-of-values map
0079 
0080       // training method
0081       void Train( void );
0082 
0083       // revoke training (required for optimise tuning parameters)
0084       void Reset( void );
0085 
0086       using MethodBase::ReadWeightsFromStream; // keeps the base-class overloads visible alongside the overloads declared below
0087 
0088       // write weights to file (TFile stream or XML node); both are const members
0089       void WriteWeightsToStream( TFile& fout   ) const;
0090       void AddWeightsXMLTo     ( void*  parent ) const;
0091 
0092       // read weights from file (text stream, TFile, or XML node)
0093       void ReadWeightsFromStream( std::istream& istr );
0094       void ReadWeightsFromStream( TFile& fFin     );
0095       void ReadWeightsFromXML   ( void*  wghtnode );
0096 
0097       // calculate the MVA value; both error output pointers are optional and default to nullptr
0098       Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0099       const std::vector<Float_t>& GetRegressionValues(); // returns a const reference to a vector of Float_t values
0100 
0101       void Init( void );
0102 
0103       // ranking of input variables: none is produced here, this always returns nullptr
0104       const Ranking* CreateRanking() { return nullptr; }
0105 
0106       // for SVM optimisation: inline setters; arguments are Double_t while the members are Float_t (fOrder is Int_t), so values are narrowed on assignment
0107       void SetGamma(Double_t g){fGamma = g;}
0108       void SetCost(Double_t c){fCost = c;}
0109       void SetMGamma(std::string & mg); // takes a non-const string reference; defined out of line
0110       void SetOrder(Double_t o){fOrder = o;}
0111       void SetTheta(Double_t t){fTheta = t;}
0112       void SetKappa(Double_t k){fKappa = k;}
0113       void SetMult(Double_t m){fMult = m;}
0114       // NOTE(review): named "Get" but takes a const vector and returns void — confirm intent against the implementation
0115       void GetMGamma(const std::vector<float> & gammas);
0116 
0117    protected:
0118 
0119       // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
0120       void MakeClassSpecific( std::ostream&, const TString& ) const;
0121 
0122       // get help message text
0123       void GetHelpMessage() const;
0124 
0125    private:
0126 
0127       // the option handling methods
0128       void DeclareOptions();
0129       void DeclareCompatibilityOptions();
0130       void ProcessOptions();
0131       Double_t getLoss( TString lossFunction ); // NOTE(review): takes a loss-function string; return semantics are defined in the implementation
0132 
0133       Float_t                       fCost;                ///< cost value
0134       Float_t                       fTolerance;           ///< tolerance parameter
0135       UInt_t                        fMaxIter;             ///< max number of iteration
0136       UShort_t                      fNSubSets;            ///< nr of subsets, default 1
0137       Float_t                       fBparm;               ///< free plane coefficient
0138       Float_t                       fGamma;               ///< RBF Kernel parameter
0139       SVWorkingSet*                 fWgSet;               ///< svm working set
0140       std::vector<TMVA::SVEvent*>*  fInputData;           ///< vector of training data in SVM format
0141       std::vector<TMVA::SVEvent*>*  fSupportVectors;      ///< contains support vectors
0142       SVKernelFunction*             fSVKernelFunction;    ///< kernel function
0143 
0144       TVectorD*                     fMinVars;             ///< for normalization //is it still needed??
0145       TVectorD*                     fMaxVars;             ///< for normalization //is it still needed??
0146 
0147       // for kernel functions
0148       TString                       fTheKernel;           ///< kernel name
0149       Float_t                       fDoubleSigmaSquared;  ///< for RBF Kernel
0150       Int_t                         fOrder;               ///< for Polynomial Kernel ( polynomial order )
0151       Float_t                       fTheta;               ///< for Sigmoidal Kernel
0152       Float_t                       fKappa;               ///< for Sigmoidal Kernel
0153       Float_t                       fMult;                ///< set via SetMult(); NOTE(review): role not documented here — confirm
0154       std::vector<Float_t>          fmGamma;              ///< vector of gammas for multi-gaussian kernel
0155       Float_t                       fNumVars;             ///< number of input variables for multi-gaussian (stored as Float_t)
0156       std::vector<TString>          fVarNames;            ///< NOTE(review): presumably the input variable names — confirm
0157       std::string                   fGammas;              ///< NOTE(review): gamma settings kept as a string — confirm format
0158       std::string                   fGammaList;           ///< NOTE(review): gamma list kept as a string — confirm format
0159       std::string                   fTune;                ///< Specify parameters to be tuned
0160       std::string                   fMultiKernels;        ///< NOTE(review): presumably the multi-kernel specification string — confirm
0161 
0162       Int_t                 fDataSize;                    ///< NOTE(review): presumably the number of training events — confirm
0163       TString fLoss;                                      ///< NOTE(review): presumably the loss function name (cf. getLoss) — confirm
0164 
0165       ClassDef(MethodSVM,0);  // Support Vector Machine
0166    };
0167 
0168 } // namespace TMVA
0169 
0170 #endif // ROOT_TMVA_MethodSVM