Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-10-23 09:24:34

0001 // @(#)root/tmva/tmva/dnn:$Id$
0002 // Author: Vladimir Ilievski, Saurav Shekhar
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : MethodDL                                                              *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      Deep Neural Network Method                                                *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Vladimir Ilievski  <ilievski.vladimir@live.com> - CERN, Switzerland       *
0015  *      Saurav Shekhar     <sauravshekhar01@gmail.com> - ETH Zurich, Switzerland  *
0016  *                                                                                *
0017  * Copyright (c) 2005-2015:                                                       *
0018  *      CERN, Switzerland                                                         *
0019  *      U. of Victoria, Canada                                                    *
0020  *      MPI-K Heidelberg, Germany                                                 *
0021  *      U. of Bonn, Germany                                                       *
0022  *                                                                                *
0023  * Redistribution and use in source and binary forms, with or without             *
0024  * modification, are permitted according to the terms listed in LICENSE           *
0025  * (see tmva/doc/LICENSE)                                          *
0026  **********************************************************************************/
0027 
0028 #ifndef ROOT_TMVA_MethodDL
0029 #define ROOT_TMVA_MethodDL
0030 
0031 //////////////////////////////////////////////////////////////////////////
0032 //                                                                      //
0033 // MethodDL                                                             //
0034 //                                                                      //
0035 // Method class for all Deep Learning Networks                          //
0036 //                                                                      //
0037 //////////////////////////////////////////////////////////////////////////
0038 
0039 #include "TString.h"
0040 
0041 #include "TMVA/MethodBase.h"
0042 #include "TMVA/Types.h"
0043 
0044 #include "TMVA/DNN/Architectures/Reference.h"
0045 
0046 //#ifdef R__HAS_TMVACPU
0047 #include "TMVA/DNN/Architectures/Cpu.h"
0048 //#endif
0049 
0050 #if 0
0051 #ifdef R__HAS_TMVAGPU
0052 #include "TMVA/DNN/Architectures/Cuda.h"
0053 #ifdef R__HAS_CUDNN
0054 #include "TMVA/DNN/Architectures/TCudnn.h"
0055 #endif
0056 #endif
0057 #endif
0058 
0059 #include "TMVA/DNN/Functions.h"
0060 #include "TMVA/DNN/DeepNet.h"
0061 
0062 #include <vector>
0063 #include <map>
0064 
0065 #ifdef R__HAS_TMVAGPU
0066 //#define USE_GPU_INFERENCE
0067 #endif
0068 
0069 namespace TMVA {
0070 
0071 /*! All of the options that can be specified in the training string. Plain aggregate: no member has a default initializer, so every field must be set explicitly by the code that fills it. */
0072 struct TTrainingSettings {
0073    size_t batchSize;                              ///< Batch size option (presumably events per training batch -- TODO confirm)
0074    size_t testInterval;                           ///< Test interval option (presumably how often the test error is evaluated -- TODO confirm units)
0075    size_t convergenceSteps;                       ///< Convergence steps option (presumably steps without improvement before stopping -- TODO confirm)
0076    size_t maxEpochs;                              ///< Maximum epochs option (presumably an upper bound on training epochs -- TODO confirm)
0077    DNN::ERegularization regularization;           ///< Regularization kind, as a DNN::ERegularization value
0078    DNN::EOptimizer optimizer;                     ///< Optimizer kind, as a DNN::EOptimizer value
0079    TString optimizerName;                         ///< Optimizer name as a string (presumably the textual form of `optimizer` -- TODO confirm)
0080    Double_t learningRate;                         ///< Learning rate option
0081    Double_t momentum;                             ///< Momentum option
0082    Double_t weightDecay;                          ///< Weight decay option
0083    std::vector<Double_t> dropoutProbabilities;    ///< Dropout probabilities (presumably one entry per layer -- TODO confirm)
0084    std::map<TString,double> optimizerParams;      ///< Additional optimizer parameters keyed by name (values are plain double, not Double_t)
0085    bool multithreading;                           ///< Multithreading flag (presumably enables multi-threaded training -- TODO confirm)
0086 };
0087 
0088 
0089 class MethodDL : public MethodBase {
0090    // Method class for all deep learning networks: options are declared and processed, then a deep net is built from the parsed options.
0091 private:
0092    // Key-Value vector type, containing the values for the training options
0093    using KeyValueVector_t = std::vector<std::map<TString, TString>>;
0094 
0095 // #ifdef R__HAS_TMVAGPU
0096 // #ifdef R__HAS_CUDNN
0097 //    using ArchitectureImpl_t = TMVA::DNN::TCudnn<Float_t>;
0098 // #else
0099 //   using ArchitectureImpl_t = TMVA::DNN::TCuda<Float_t>;
0100 // #endif
0101 // #else
0102 // do not use GPU architecture for evaluation. It is too slow for batch size=1
0103    using ArchitectureImpl_t = TMVA::DNN::TCpu<Float_t>;
0104 // #endif
0105    // Shorthand aliases derived from the architecture selected above
0106    using DeepNetImpl_t = TMVA::DNN::TDeepNet<ArchitectureImpl_t>;
0107    using MatrixImpl_t =  typename ArchitectureImpl_t::Matrix_t;
0108    using TensorImpl_t =  typename ArchitectureImpl_t::Tensor_t;
0109    using ScalarImpl_t =  typename ArchitectureImpl_t::Scalar_t;
0110    using HostBufferImpl_t = typename ArchitectureImpl_t::HostBuffer_t;
0111 
0112    /*! The option handling methods */
0113    void DeclareOptions();
0114    void ProcessOptions();
0115 
0116    void Init();
0117 
0118    // Functions to parse the layout of the input and of the batch (presumably from fInputLayoutString / fBatchLayoutString -- confirm in the implementation)
0119    void ParseInputLayout();
0120    void ParseBatchLayout();
0121 
0122    /*! After calling ProcessOptions(), all of the options are parsed,
0123     *  so using the parsed options, and given the architecture and the
0124     *  type of the layers, we build the Deep Network passed as
0125     *  a reference in the function. */
0126    template <typename Architecture_t, typename Layer_t>
0127    void CreateDeepNet(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0128                       std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets);
0129    // Per-layer-type parsers: each takes the net being built, a vector of nets, the string describing the layer and a delimiter.
0130    template <typename Architecture_t, typename Layer_t>
0131    void ParseDenseLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0132                         std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0133 
0134    template <typename Architecture_t, typename Layer_t>
0135    void ParseConvLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0136                        std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0137 
0138    template <typename Architecture_t, typename Layer_t>
0139    void ParseMaxPoolLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0140                           std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0141                           TString delim);
0142 
0143    template <typename Architecture_t, typename Layer_t>
0144    void ParseReshapeLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0145                           std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0146                           TString delim);
0147 
0148    template <typename Architecture_t, typename Layer_t>
0149    void ParseBatchNormLayer(DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0150                           std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString,
0151                           TString delim);
0152 
0153    enum ERecurrentLayerType { kLayerRNN = 0, kLayerLSTM = 1, kLayerGRU = 2 }; ///< Recurrent layer kind passed to ParseRecurrentLayer
0154    template <typename Architecture_t, typename Layer_t>
0155    void ParseRecurrentLayer(ERecurrentLayerType type, DNN::TDeepNet<Architecture_t, Layer_t> &deepNet,
0156                        std::vector<DNN::TDeepNet<Architecture_t, Layer_t>> &nets, TString layerString, TString delim);
0157 
0158 
0159    /// train the deep neural network using the architecture given as template argument
0160    template <typename Architecture_t>
0161    void TrainDeepNet();
0162 
0163    /// perform prediction of the deep neural network
0164    /// using batches (called by GetMvaValues)
0165    template <typename Architecture_t>
0166    std::vector<Double_t> PredictDeepNet(Long64_t firstEvt, Long64_t lastEvt, size_t batchSize, Bool_t logProgress);
0167 
0168    /// Get the input event tensor for evaluation
0169    /// Internal function to fill the fXInput tensor with the correct shape from TMVA current Event class
0170    void FillInputTensor();
0171 
0172    /// parse the validation string and return the number of events used for validation
0173    UInt_t GetNumValidationSamples();
0174 
0175    // cudnn implementation needs this format
0176    /** Contains the batch size (no. of images in the batch), input depth (no. channels)
0177     *  and further input dimensions of the data (image height, width ...)*/
0178    std::vector<size_t> fInputShape;
0179 
0180    // The size of the batch, i.e. the number of images that are contained in the batch, is either set to be the depth
0181    // or the height of the batch
0182    size_t fBatchDepth;  ///< The depth of the batch used to train the deep net.
0183    size_t fBatchHeight; ///< The height of the batch used to train the deep net.
0184    size_t fBatchWidth;  ///< The width of the batch used to train the deep net.
0185 
0186    size_t fRandomSeed;  ///< The random seed used to initialize the weights and shuffling batches (default is zero)
0187 
0188    DNN::EInitialization fWeightInitialization; ///< The initialization method
0189    DNN::EOutputFunction fOutputFunction;       ///< The output function for making the predictions
0190    DNN::ELossFunction   fLossFunction;         ///< The loss function
0191 
0192    TString fInputLayoutString;          ///< The string defining the layout of the input
0193    TString fBatchLayoutString;          ///< The string defining the layout of the batch
0194    TString fLayoutString;               ///< The string defining the layout of the deep net
0195    TString fErrorStrategy;              ///< The string defining the error strategy for training
0196    TString fTrainingStrategyString;     ///< The string defining the training strategy
0197    TString fWeightInitializationString; ///< The string defining the weight initialization method
0198    TString fArchitectureString;         ///< The string defining the architecture: CPU or GPU
0199    TString fNumValidationString;        ///< The string defining the number (or percentage) of training data used for validation
0200    bool fResume;                       ///< NOTE(review): undocumented flag; presumably requests resuming a previous training -- confirm in the implementation
0201    bool fBuildNet;                     ///< Flag to control whether to build fNet, the stored network used for the evaluation
0202 
0203    KeyValueVector_t fSettings;                       ///< Map for the training strategy
0204    std::vector<TTrainingSettings> fTrainingSettings; ///< The vector defining each training strategy
0205 
0206    TensorImpl_t fXInput;                 // input tensor used to evaluate fNet
0207    HostBufferImpl_t fXInputBuffer;        // input host buffer corresponding to X (needed for GPU implementation)
0208    std::unique_ptr<MatrixImpl_t> fYHat;   // output prediction matrix of fNet
0209    std::unique_ptr<DeepNetImpl_t> fNet;   // the stored network used for the evaluation (see fBuildNet); owned by this object
0210    // NOTE(review): std::unique_ptr (above) and std::move (in SetInputShape) are used, but <memory> / <utility> are not among this header's direct includes.
0211 
0212    ClassDef(MethodDL, 0); // ROOT ClassDef macro, class version 0
0213 
0214 protected:
0215    // provide a help message
0216    void GetHelpMessage() const;
0217 
0218    virtual std::vector<Double_t> GetMvaValues(Long64_t firstEvt, Long64_t lastEvt, Bool_t logProgress);
0219 
0220 
0221 public:
0222    /*! Constructor */
0223    MethodDL(const TString &jobName, const TString &methodTitle, DataSetInfo &theData, const TString &theOption);
0224 
0225    /*! Constructor */
0226    MethodDL(DataSetInfo &theData, const TString &theWeightFile);
0227 
0228    /*! Virtual Destructor */
0229    virtual ~MethodDL();
0230 
0231    /*! Function for parsing the training settings, provided as a string
0232     *  in a key-value form.  */
0233    KeyValueVector_t ParseKeyValueString(TString parseString, TString blockDelim, TString tokenDelim);
0234 
0235    /*! Check the type of analysis the deep learning network can do */
0236    Bool_t HasAnalysisType(Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets);
0237 
0238    /*! Methods for training the deep learning network */
0239    void Train();
0240 
0241    Double_t GetMvaValue(Double_t *err = nullptr, Double_t *errUpper = nullptr);
0242    virtual const std::vector<Float_t>& GetRegressionValues();
0243    virtual const std::vector<Float_t>& GetMulticlassValues();
0244 
0245    /*! Methods for writing and reading weights */
0246    using MethodBase::ReadWeightsFromStream; // keep the base-class overloads visible next to the std::istream overload declared below
0247    void AddWeightsXMLTo(void *parent) const;
0248    void ReadWeightsFromXML(void *wghtnode);
0249    void ReadWeightsFromStream(std::istream &);
0250 
0251    /* Create ranking */
0252    const Ranking *CreateRanking();
0253 
0254    /* Getters */
0255    size_t GetInputDepth()  const { return fInputShape[1]; }   //< no. of channels for an image
0256    size_t GetInputHeight() const { return fInputShape[2]; }   // NOTE(review): unchecked operator[]; requires fInputShape.size() > 2
0257    size_t GetInputWidth()  const { return fInputShape[3]; }   // NOTE(review): unchecked operator[]; requires fInputShape.size() > 3
0258    size_t GetInputDim()    const { return fInputShape.size() - 2; }   // NOTE(review): unsigned subtraction; wraps around if fInputShape has fewer than 2 entries
0259    std::vector<size_t> GetInputShape() const { return fInputShape; }   // returns a copy of the shape vector
0260 
0261    size_t GetBatchSize()   const { return fInputShape[0]; }   // the batch size is stored as element 0 of fInputShape
0262    size_t GetBatchDepth()  const { return fBatchDepth; }
0263    size_t GetBatchHeight() const { return fBatchHeight; }
0264    size_t GetBatchWidth()  const { return fBatchWidth; }
0265 
0266    const DeepNetImpl_t & GetDeepNet() const { return *fNet; }   // NOTE(review): dereferences fNet without a null check
0267 
0268    DNN::EInitialization GetWeightInitialization() const { return fWeightInitialization; }
0269    DNN::EOutputFunction GetOutputFunction()       const { return fOutputFunction; }
0270    DNN::ELossFunction GetLossFunction()           const { return fLossFunction; }
0271 
0272    TString GetInputLayoutString()          const { return fInputLayoutString; }
0273    TString GetBatchLayoutString()          const { return fBatchLayoutString; }
0274    TString GetLayoutString()               const { return fLayoutString; }
0275    TString GetErrorStrategyString()        const { return fErrorStrategy; }
0276    TString GetTrainingStrategyString()     const { return fTrainingStrategyString; }
0277    TString GetWeightInitializationString() const { return fWeightInitializationString; }
0278    TString GetArchitectureString()         const { return fArchitectureString; }
0279 
0280    const std::vector<TTrainingSettings> &GetTrainingSettings() const { return fTrainingSettings; }
0281    std::vector<TTrainingSettings>       &GetTrainingSettings()       { return fTrainingSettings; }
0282    const KeyValueVector_t               &GetKeyValueSettings() const { return fSettings; }
0283    KeyValueVector_t                     &GetKeyValueSettings()       { return fSettings; }
0284 
0285    /** Setters */
0286    void SetInputDepth (int inputDepth)  { fInputShape[1] = inputDepth; }   // NOTE(review): int is converted to size_t on assignment; index is unchecked
0287    void SetInputHeight(int inputHeight) { fInputShape[2] = inputHeight; }  // same int -> size_t conversion and unchecked index as above
0288    void SetInputWidth (int inputWidth)  { fInputShape[3] = inputWidth; }   // same int -> size_t conversion and unchecked index as above
0289    void SetInputShape (std::vector<size_t> inputShape) { fInputShape = std::move(inputShape); }   // by-value parameter is moved into the member
0290 
0291    void SetBatchSize  (size_t batchSize)   { fInputShape[0] = batchSize; }   // writes element 0 of fInputShape (unchecked index)
0292    void SetBatchDepth (size_t batchDepth)  { fBatchDepth    = batchDepth; }
0293    void SetBatchHeight(size_t batchHeight) { fBatchHeight   = batchHeight; }
0294    void SetBatchWidth (size_t batchWidth)  { fBatchWidth    = batchWidth; }
0295 
0296    void SetWeightInitialization(DNN::EInitialization weightInitialization)
0297    {
0298       fWeightInitialization = weightInitialization;
0299    }
0300    void SetOutputFunction            (DNN::EOutputFunction outputFunction) { fOutputFunction = outputFunction; }
0301    void SetErrorStrategyString       (TString errorStrategy)               { fErrorStrategy = errorStrategy; }
0302    void SetTrainingStrategyString    (TString trainingStrategyString)      { fTrainingStrategyString = trainingStrategyString; }
0303    void SetWeightInitializationString(TString weightInitializationString)
0304    {
0305       fWeightInitializationString = weightInitializationString;
0306    }
0307    void SetArchitectureString        (TString architectureString)          { fArchitectureString = architectureString; }
0308    void SetLayoutString              (TString layoutString)                { fLayoutString = layoutString; }
0309 };
0310 
0311 } // namespace TMVA
0312 
0313 #endif