Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:57

0001 // @(#)root/tmva/tmva/dnn:$Id$
0002 // Author: Ravi Kiran S
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class  : VOptimizer                                                            *
0008  *                                             *
0009  *                                                                                *
0010  * Description:                                                                   *
0011  *      General Optimizer Class                                                   *
0012  *                                                                                *
0013  * Authors (alphabetical):                                                        *
0014  *      Ravi Kiran S      <sravikiran0606@gmail.com>  - CERN, Switzerland         *
0015  *                                                                                *
0016  * Copyright (c) 2005-2018 :                                                      *
0017  *      CERN, Switzerland                                                         *
0018  *      U. of Victoria, Canada                                                    *
0019  *      MPI-K Heidelberg, Germany                                                 *
0020  *      U. of Bonn, Germany                                                       *
0021  *                                                                                *
0022  * Redistribution and use in source and binary forms, with or without             *
0023  * modification, are permitted according to the terms listed in LICENSE           *
0024  * (see tmva/doc/LICENSE)                                          *
0025  **********************************************************************************/
0026 
0027 #ifndef TMVA_DNN_OPTIMIZER
0028 #define TMVA_DNN_OPTIMIZER
0029 
0030 #include "TMVA/DNN/GeneralLayer.h"
0031 #include "TMVA/DNN/DeepNet.h"
0032 #include <vector>
0033 
0034 namespace TMVA {
0035 namespace DNN {
0036 
0037 /** \class VOptimizer
0038     Generic Optimizer class
0039 
0040     This class represents the general class for all optimizers in the Deep Learning
0041     Module.
0042  */
0043 template <typename Architecture_t, typename Layer_t = VGeneralLayer<Architecture_t>,
0044           typename DeepNet_t = TDeepNet<Architecture_t, Layer_t>>
0045 class VOptimizer {
0046 public:
0047    using Matrix_t = typename Architecture_t::Matrix_t;
0048    using Scalar_t = typename Architecture_t::Scalar_t;
0049 
0050 protected:
0051    Scalar_t fLearningRate; ///< The learning rate used for training.
0052    size_t fGlobalStep;     ///< The current global step count during training.
0053    DeepNet_t &fDeepNet;    ///< The reference to the deep net.
0054 
0055    /*! Update the weights, given the current weight gradients. */
0056    virtual void
0057    UpdateWeights(size_t layerIndex, std::vector<Matrix_t> &weights, const std::vector<Matrix_t> &weightGradients) = 0;
0058 
0059    /*! Update the biases, given the current bias gradients. */
0060    virtual void
0061    UpdateBiases(size_t layerIndex, std::vector<Matrix_t> &biases, const std::vector<Matrix_t> &biasGradients) = 0;
0062 
0063 public:
0064    /*! Constructor. */
0065    VOptimizer(Scalar_t learningRate, DeepNet_t &deepNet);
0066 
0067    /*! Performs one step of optimization. */
0068    void Step();
0069 
0070    /*! Virtual Destructor. */
0071    virtual ~VOptimizer() = default;
0072 
0073    /*! Increments the global step. */
0074    void IncrementGlobalStep() { this->fGlobalStep++; }
0075 
0076    /*! Getters */
0077    Scalar_t GetLearningRate() const
0078    {
0079       return fLearningRate;
0080    }
0081    size_t GetGlobalStep() const { return fGlobalStep; }
0082    std::vector<Layer_t *> &GetLayers() { return fDeepNet.GetLayers(); }
0083    Layer_t *GetLayerAt(size_t i) { return fDeepNet.GetLayerAt(i); }
0084 
0085    /*! Setters */
0086    void SetLearningRate(size_t learningRate) { fLearningRate = learningRate; }
0087 };
0088 
0089 //
0090 //
0091 //  The General Optimizer Class - Implementation
0092 //_________________________________________________________________________________________________
0093 template <typename Architecture_t, typename Layer_t, typename DeepNet_t>
0094 VOptimizer<Architecture_t, Layer_t, DeepNet_t>::VOptimizer(Scalar_t learningRate, DeepNet_t &deepNet)
0095    : fLearningRate(learningRate), fGlobalStep(0), fDeepNet(deepNet)
0096 {
0097 }
0098 
0099 //_________________________________________________________________________________________________
0100 template <typename Architecture_t, typename Layer_t, typename DeepNet_t>
0101 auto VOptimizer<Architecture_t, Layer_t, DeepNet_t>::Step() -> void
0102 {
0103    for (size_t i = 0; i < this->GetLayers().size(); i++) {
0104       this->UpdateWeights(i, this->GetLayerAt(i)->GetWeights(), this->GetLayerAt(i)->GetWeightGradients());
0105       this->UpdateBiases(i, this->GetLayerAt(i)->GetBiases(), this->GetLayerAt(i)->GetBiasGradients());
0106    }
0107 }
0108 
0109 } // namespace DNN
0110 } // namespace TMVA
0111 
0112 #endif