Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2024-11-15 09:56:15

0001 // @(#)root/tmva $Id$
0002 // Author: Peter Speckmayer
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : MethodDNN                                                             *
0008  * Web    : http://tmva.sourceforge.net                                           *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      NeuralNetwork                                                             *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Peter Speckmayer      <peter.speckmayer@gmx.at>  - CERN, Switzerland      *
0015  *      Simon Pfreundschuh    <s.pfreundschuh@gmail.com> - CERN, Switzerland      *
0016  *                                                                                *
0017  * Copyright (c) 2005-2015:                                                       *
0018  *      CERN, Switzerland                                                         *
0019  *      U. of Victoria, Canada                                                    *
0020  *      MPI-K Heidelberg, Germany                                                 *
0021  *      U. of Bonn, Germany                                                       *
0022  *                                                                                *
0023  * Redistribution and use in source and binary forms, with or without             *
0024  * modification, are permitted according to the terms listed in LICENSE           *
0025  * (http://tmva.sourceforge.net/LICENSE)                                          *
0026  **********************************************************************************/
0027 
0028 //#pragma once
0029 
0030 #ifndef ROOT_TMVA_MethodDNN
0031 #define ROOT_TMVA_MethodDNN
0032 
0033 //////////////////////////////////////////////////////////////////////////
0034 //                                                                      //
0035 // MethodDNN                                                            //
0036 //                                                                      //
0037 // Neural Network implementation                                        //
0038 //                                                                      //
0039 //////////////////////////////////////////////////////////////////////////
0040 
0041 #include <vector>
0042 #include <map>
0043 #include <string>
0044 #include <sstream>
0045 
0046 #include "TString.h"
0047 #include "TTree.h"
0048 #include "TRandom3.h"
0049 #include "TH1F.h"
0050 #include "TMVA/MethodBase.h"
0051 #include "TMVA/NeuralNet.h"
0052 
0053 #include "TMVA/Tools.h"
0054 
0055 #include "TMVA/DNN/Net.h"
0056 #include "TMVA/DNN/Minimizers.h"
0057 #include "TMVA/DNN/Architectures/Reference.h"
0058 
0059 #ifdef R__HAS_TMVACPU
0060 #define DNNCPU
0061 #endif
0062 #ifdef R__HAS_TMVAGPU
0063 //#define DNNCUDA
0064 #endif
0065 
0066 #ifdef DNNCPU
0067 #include "TMVA/DNN/Architectures/Cpu.h"
0068 #endif
0069 
0070 #ifdef DNNCUDA
0071 #include "TMVA/DNN/Architectures/Cuda.h"
0072 #endif
0073 
0074 namespace TMVA {
0075 
0076 class MethodDNN : public MethodBase
0077 {
0078    friend struct TestMethodDNNValidationSize;
0079 
0080    using Architecture_t = DNN::TReference<Float_t>;
0081    using Net_t          = DNN::TNet<Architecture_t>;
0082    using Matrix_t       = typename Architecture_t::Matrix_t;
0083    using Scalar_t       = typename Architecture_t::Scalar_t;
0084 
0085 private:
0086    using LayoutVector_t   = std::vector<std::pair<int, DNN::EActivationFunction>>;
0087    using KeyValueVector_t = std::vector<std::map<TString, TString>>;
0088 
0089    struct TTrainingSettings
0090    {
0091        size_t                batchSize;
0092        size_t                testInterval;
0093        size_t                convergenceSteps;
0094        DNN::ERegularization  regularization;
0095        Double_t              learningRate;
0096        Double_t              momentum;
0097        Double_t              weightDecay;
0098        std::vector<Double_t> dropoutProbabilities;
0099        bool                  multithreading;
0100    };
0101 
0102    // the option handling methods
0103    void DeclareOptions();
0104    void ProcessOptions();
0105 
0106    UInt_t GetNumValidationSamples();
0107 
0108    // general helper functions
0109    void     Init();
0110 
0111    Net_t                fNet;
0112    DNN::EInitialization fWeightInitialization;
0113    DNN::EOutputFunction fOutputFunction;
0114 
0115    TString                        fLayoutString;
0116    TString                        fErrorStrategy;
0117    TString                        fTrainingStrategyString;
0118    TString                        fWeightInitializationString;
0119    TString                        fArchitectureString;
0120    TString                        fValidationSize;
0121    LayoutVector_t                 fLayout;
0122    std::vector<TTrainingSettings> fTrainingSettings;
0123    bool                           fResume;
0124 
0125    KeyValueVector_t fSettings;
0126 
0127    ClassDef(MethodDNN,0); // neural network
0128 
0129    static inline void WriteMatrixXML(void *parent, const char *name,
0130                                      const TMatrixT<Double_t> &X);
0131    static inline void ReadMatrixXML(void *xml, const char *name,
0132                                     TMatrixT<Double_t> &X);
0133 protected:
0134 
0135    void MakeClassSpecific( std::ostream&, const TString& ) const;
0136    void GetHelpMessage() const;
0137 
0138 public:
0139 
0140    // Standard Constructors
0141    MethodDNN(const TString& jobName,
0142              const TString&  methodTitle,
0143              DataSetInfo& theData,
0144              const TString& theOption);
0145    MethodDNN(DataSetInfo& theData,
0146              const TString& theWeightFile);
0147    virtual ~MethodDNN();
0148 
0149    virtual Bool_t HasAnalysisType(Types::EAnalysisType type,
0150                                   UInt_t numberClasses,
0151                                   UInt_t numberTargets );
0152    LayoutVector_t   ParseLayoutString(TString layerSpec);
0153    KeyValueVector_t ParseKeyValueString(TString parseString,
0154                                       TString blockDelim,
0155                                       TString tokenDelim);
0156    void Train();
0157    void TrainGpu();
0158    void TrainCpu();
0159 
0160    virtual Double_t GetMvaValue( Double_t* err = nullptr, Double_t* errUpper = nullptr );
0161    virtual const std::vector<Float_t>& GetRegressionValues();
0162    virtual const std::vector<Float_t>& GetMulticlassValues();
0163 
0164    using MethodBase::ReadWeightsFromStream;
0165 
0166    // write weights to stream
0167    void AddWeightsXMLTo     ( void* parent ) const;
0168 
0169    // read weights from stream
0170    void ReadWeightsFromStream( std::istream & i );
0171    void ReadWeightsFromXML   ( void* wghtnode );
0172 
0173    // ranking of input variables
0174    const Ranking* CreateRanking();
0175 
0176 };
0177 
0178 inline void MethodDNN::WriteMatrixXML(void *parent,
0179                                       const char *name,
0180                                       const TMatrixT<Double_t> &X)
0181 {
0182    std::stringstream matrixStringStream("");
0183    matrixStringStream.precision( 16 );
0184 
0185    for (size_t i = 0; i < (size_t) X.GetNrows(); i++)
0186    {
0187       for (size_t j = 0; j < (size_t) X.GetNcols(); j++)
0188       {
0189          matrixStringStream << std::scientific << X(i,j) << " ";
0190       }
0191    }
0192    std::string s = matrixStringStream.str();
0193    void* matxml = gTools().xmlengine().NewChild(parent, nullptr, name);
0194    gTools().xmlengine().NewAttr(matxml, nullptr, "rows",
0195                                 gTools().StringFromInt((int)X.GetNrows()));
0196    gTools().xmlengine().NewAttr(matxml, nullptr, "cols",
0197                                 gTools().StringFromInt((int)X.GetNcols()));
0198    gTools().xmlengine().AddRawLine (matxml, s.c_str());
0199 }
0200 
0201 inline void MethodDNN::ReadMatrixXML(void *xml,
0202                                      const char *name,
0203                                      TMatrixT<Double_t> &X)
0204 {
0205    void *matrixXML = gTools().GetChild(xml, name);
0206    size_t rows, cols;
0207    gTools().ReadAttr(matrixXML, "rows", rows);
0208    gTools().ReadAttr(matrixXML, "cols", cols);
0209 
0210    const char * matrixString = gTools().xmlengine().GetNodeContent(matrixXML);
0211    std::stringstream matrixStringStream(matrixString);
0212 
0213    for (size_t i = 0; i < rows; i++)
0214    {
0215       for (size_t j = 0; j < cols; j++)
0216       {
0217          matrixStringStream >> X(i,j);
0218       }
0219    }
0220 }
0221 } // namespace TMVA
0222 
0223 #endif