#ifndef TMVA_DNN_BatchNormLayer
#define TMVA_DNN_BatchNormLayer

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include "TMVA/DNN/Architectures/Reference.h"

#include "TMVA/DNN/CNN/ContextHandles.h"

#include <iostream>
#include <iomanip>
#include <vector>

namespace TMVA {
namespace DNN {
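
/** \class TBatchNormLayer

   Layer implementing batch normalization.

   During training the inputs of each batch are normalized to zero mean and
   unit variance, and are then scaled and shifted by two learnable parameters
   per normalized dimension: a scale factor gamma (weight 0) and an offset
   beta (weight 1):

      y = gamma * (x - mu_batch) / sqrt(var_batch + epsilon) + beta

   A running mean and variance are accumulated during training and used to
   normalize the input at inference time. If the momentum is in [0,1), the
   running statistics are exponential moving averages,

      running_mean = momentum * running_mean + (1 - momentum) * batch_mean

   while for a negative momentum (the default) the cumulative average over
   all trained batches is used instead. See https://arxiv.org/abs/1502.03167
   for the original description of the method.
*/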
template <typename Architecture_t>
class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
public:
   using Scalar_t = typename Architecture_t::Scalar_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Tensor_t = typename Architecture_t::Tensor_t;

   using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
   using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t;

private:
   Tensor_t fDerivatives; ///< First derivatives of the activations of this layer

   int fNormAxis; ///< Normalization axis; for each element of this axis a mean and variance are computed

   Scalar_t fMomentum; ///< Momentum of the running mean/variance; a negative value selects a cumulative average
   Scalar_t fEpsilon;  ///< Small constant added to the variance to avoid division by zero

   Matrix_t fMu;   ///< Mean of the current batch
   Matrix_t fVar;  ///< Variance of the current batch
   Matrix_t fIVar; ///< Inverse variance of the current batch, cached for the backward pass

   Matrix_t fMu_Training;  ///< Running mean, used for normalization at inference time
   Matrix_t fVar_Training; ///< Running variance, used for normalization at inference time

   Tensor_t fReshapedData; ///< Cached tensor carrying the shape and layout expected by the architecture kernels

   int fTrainedBatches = 0; ///< Number of batches seen so far, used to update the running statistics

   TDescriptors *fDescriptors = nullptr; ///< Architecture-specific batch normalization descriptors

public:
   /*! Constructor */
   TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                   const std::vector<size_t> &shape, int axis = -1, Scalar_t momentum = -1.,
                   Scalar_t epsilon = 0.0001);

   /*! Copy the batch normalization layer provided as a pointer */
   TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer);

   /*! Copy constructor */
   TBatchNormLayer(const TBatchNormLayer &);

   /*! Destructor */
   ~TBatchNormLayer();

   /*! Compute the output of the layer from the given input. In training mode
       the batch statistics are computed and the running averages updated;
       in inference mode the stored running mean and variance are used. */
   void Forward(Tensor_t &input, bool inTraining = true);

   /*! Compute the gradients with respect to the input and to the layer
       parameters (gamma and beta) from the activation gradients. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /*! Reset the counter of trained batches (e.g. at the start of a new epoch) */
   void ResetTraining() { fTrainedBatches = 0; }

   /*! Print the layer info */
   void Print() const;

   /*! Write the layer weights and parameters to XML */
   virtual void AddWeightsXMLTo(void *parent);

   /*! Read the layer weights and parameters from XML */
   virtual void ReadWeightsFromXML(void *parent);

   /*! Initialize the layer parameters */
   virtual void Initialize();

   /*! Get the number of batches trained since the last reset */
   const int &GetNTrainedBatches() const { return fTrainedBatches; }
   int &GetNTrainedBatches() { return fTrainedBatches; }

   /*! Get the mean of the current batch */
   const Matrix_t &GetBatchMean() const { return fMu; }
   Matrix_t &GetBatchMean() { return fMu; }

   /*! Get the variance of the current batch */
   const Matrix_t &GetVariance() const { return fVar; }
   Matrix_t &GetVariance() { return fVar; }

   /*! Get the inverse variance of the current batch */
   const Matrix_t &GetIVariance() const { return fIVar; }
   Matrix_t &GetIVariance() { return fIVar; }

   /*! Get the running mean used at inference time */
   const Matrix_t &GetMuVector() const { return fMu_Training; }
   Matrix_t &GetMuVector() { return fMu_Training; }

   /*! Get the running variance used at inference time */
   const Matrix_t &GetVarVector() const { return fVar_Training; }
   Matrix_t &GetVarVector() { return fVar_Training; }

   /*! Get the momentum of the running mean/variance */
   Scalar_t GetMomentum() const { return fMomentum; }

   /*! Get the epsilon used in the normalization */
   Scalar_t GetEpsilon() const { return fEpsilon; }

   /*! Get the normalization axis */
   int GetNormAxis() const { return fNormAxis; }

   const Tensor_t &GetReshapedData() const { return fReshapedData; }
   Tensor_t &GetReshapedData() { return fReshapedData; }

   /*! Return the running mean and variance as extra (non-trainable) parameters */
   std::vector<Matrix_t> GetExtraLayerParameters() const
   {
      std::vector<Matrix_t> params(2);
      params[0] = this->GetMuVector();
      params[1] = this->GetVarVector();
      return params;
   }

   /*! Set the running mean and variance from extra (non-trainable) parameters */
   void SetExtraLayerParameters(const std::vector<Matrix_t> &params)
   {
      this->GetMuVector() = params[0];
      this->GetVarVector() = params[1];
   }

protected:
   /*! Size of the normalized dimension, i.e. the number of gamma/beta
       parameters, for the given normalization axis (-1 normalizes across
       all input features) */
   static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
   {
      if (axis == -1)
         return c * h * w;
      else if (axis == 1)
         return c;
      else if (axis == 2)
         return h;
      else if (axis == 3)
         return w;
      return 0;
   }
};
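
// A minimal usage sketch for a dense-layer setup, assuming the input is a
// rank-3 tensor of shape {batchSize, inputUnits, 1}; the Architecture_t alias
// and the shape vector below are illustrative assumptions, not fixed by this
// header:
//
//    using Architecture_t = TMVA::DNN::TCpu<Double_t>;
//    size_t batchSize = 32, inputUnits = 16;
//    std::vector<size_t> shape = {batchSize, inputUnits, 1};
//    TMVA::DNN::TBatchNormLayer<Architecture_t> bnorm(batchSize, 1, 1, inputUnits,
//                                                     shape /*, axis = -1 */);
//    bnorm.Initialize();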
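
/*! General constructor: the output shape is equal to the input shape, and the
    number of normalization parameters (the number of columns of gamma and
    beta) is given by CalculateNormDim for the chosen axis. */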
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
                                                 size_t inputWidth, const std::vector<size_t> &shape, int axis,
                                                 Scalar_t momentum, Scalar_t epsilon)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // batch size and input shape
                                   inputDepth, inputHeight, inputWidth,            // output shape (same as input)
                                   2, 1,                                           // two weight matrices (gamma and beta) of one row each
                                   CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight matrix columns
                                   1, 1, 1,                                        // single (unused) 1x1 bias
                                   shape[2], shape[0], shape[1],                   // output tensor shape
                                   EInitialization::kZero),
     fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
     fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fReshapedData(1, 1, 1) // reshaped later, when the input layout is known
{
}

template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // copying of the layer-specific members is not implemented
   std::cerr << "Error - copy ctor not implemented\n";
}

template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(const TBatchNormLayer &layer) : VGeneralLayer<Architecture_t>(layer)
{
   // copying of the layer-specific members is not implemented
   std::cerr << "Error - copy ctor not implemented\n";
}

template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::~TBatchNormLayer()
{
   // release the architecture-specific descriptors
   if (fDescriptors) {
      Architecture_t::ReleaseBNormDescriptors(fDescriptors);
      delete fDescriptors;
   }
}

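/*! Initialize the layer: gamma is set to 1 and beta to 0, so the layer starts
    out as an identity transformation; the running mean and variance are set
    to 0 and 1 respectively, and the batch counter is reset. */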
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Initialize() -> void
{
   Matrix_t &gamma = this->GetWeightsAt(0);
   Matrix_t &beta = this->GetWeightsAt(1);
   size_t bndim = gamma.GetNcols();

   initialize<Architecture_t>(beta, EInitialization::kZero);
   for (size_t i = 0; i < bndim; ++i) {
      gamma(0, i) = 1.;
      // default running statistics: zero mean, unit variance
      fMu_Training(0, i) = 0;
      fVar_Training(0, i) = 1;
   }

   Matrix_t &dgamma = this->GetWeightGradientsAt(0);
   Matrix_t &dbeta = this->GetWeightGradientsAt(1);
   initialize<Architecture_t>(dgamma, EInitialization::kZero);
   initialize<Architecture_t>(dbeta, EInitialization::kZero);

   fTrainedBatches = 0;

   Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
}
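/*! Apply batch normalization to the input tensor. In training mode the batch
    mean and variance are computed, the input is normalized, and the running
    averages are updated; in inference mode the stored running statistics are
    used instead of the batch statistics. */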
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Forward(Tensor_t &x, bool inTraining) -> void
{
   Tensor_t x2;
   Tensor_t y2;
   // wrap the input and output buffers in tensors with the layout expected
   // by the architecture kernels, if the layouts differ
   if (x.GetLayout() != fReshapedData.GetLayout()) {
      x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
   } else {
      x2 = x;
      y2 = this->GetOutput();
   }

   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);
   if (inTraining) {
      Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
                                                    this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                    this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                                    this->GetMuVector(),
                                                    this->GetVarVector(), this->GetNTrainedBatches(),
                                                    this->GetMomentum(), this->GetEpsilon(),
                                                    descr->HelperDescriptor);
      fTrainedBatches++;
   } else {
      // normalize using the running mean and variance accumulated during training
      Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                     y2, this->GetMuVector(), this->GetVarVector(),
                                                     this->GetEpsilon(), descr->HelperDescriptor);
      fTrainedBatches = 0;
   }
}
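/*! Compute the gradients with respect to the input and to gamma and beta.
    For a batch of size m with normalized inputs xhat, the standard batch
    normalization gradients are
       dbeta  = sum_b dy_b ,
       dgamma = sum_b dy_b * xhat_b ,
       dx_b   = gamma / sqrt(var + epsilon) / m * (m * dy_b - dbeta - xhat_b * dgamma) ,
    which the architecture backend is expected to implement. */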
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                               const Tensor_t &activations_backward) -> void
{
   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);

   if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
      // wrap the buffers in tensors with the layout expected by the architecture kernels
      Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(),
                             fReshapedData.GetLayout());

      Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
                                             this->GetWeightsAt(0), // gamma
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   } else {
      Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward, // x
                                             this->GetActivationGradients(),  // dy
                                             gradients_backward,              // dx
                                             this->GetWeightsAt(0),           // gamma
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   }
}

template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Print() const
{
   std::cout << " BATCH NORM Layer: \t";
   std::cout << " Input/Output = ( ";
   auto &shape = this->GetOutput().GetShape();
   for (size_t i = 0; i < shape.size(); ++i) {
      if (i > 0) std::cout << " , ";
      std::cout << shape[i];
   }
   std::cout << " ) ";
   std::cout << "\t Norm dim = " << std::setw(6) << this->GetWeightsAt(0).GetNcols();
   std::cout << "\t axis = " << fNormAxis << std::endl;
   std::cout << std::endl;
}
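/*! Write the layer parameters to XML: the momentum and epsilon attributes,
    the running mean and variance accumulated during training, and the
    learned gamma and beta weights. */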
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "BatchNormLayer");

   gTools().AddAttr(layerxml, "Momentum", fMomentum);
   gTools().AddAttr(layerxml, "Epsilon", fEpsilon);

   // write the running statistics
   this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
   this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());

   // write the learned scale and offset
   this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
}

/*! Read the layer parameters back from XML, in the same order as they are
    written by AddWeightsXMLTo. */
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   gTools().ReadAttr(parent, "Momentum", fMomentum);
   gTools().ReadAttr(parent, "Epsilon", fEpsilon);

   this->ReadMatrixXML(parent, "Training-mu", this->GetMuVector());
   this->ReadMatrixXML(parent, "Training-variance", this->GetVarVector());

   this->ReadMatrixXML(parent, "Gamma", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "Beta", this->GetWeightsAt(1));
}

} // namespace DNN
} // namespace TMVA

#endif