// Implementation of the batch normalization layer for the TMVA deep neural
// network (DNN) module.
#ifndef TMVA_DNN_BatchNormLayer
#define TMVA_DNN_BatchNormLayer

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include "TMVA/DNN/Architectures/Reference.h"

#include "TMVA/DNN/CNN/ContextHandles.h"

#include <iostream>
#include <iomanip>
#include <vector>

namespace TMVA {
namespace DNN {

/** \class TBatchNormLayer

    Layer implementing batch normalization.

    During training the layer normalizes each input feature over the
    current batch:

        y = gamma * (x - mean_batch) / sqrt(var_batch + epsilon) + beta

    where gamma and beta are learnable parameters of the layer (stored as
    the weight matrices) and mean_batch / var_batch are the mean and
    variance of the feature computed on the batch. Running estimates of
    the mean and variance are accumulated during training and used in
    place of the batch statistics at inference time. A momentum in [0,1]
    selects an exponential moving average for these running estimates,
    while a negative momentum (the default) selects a cumulative average.

    The axis argument controls the normalization dimension: axis = -1
    normalizes each input feature separately (one mean and variance per
    feature, suitable after a dense layer), while axis = 1 computes one
    mean and variance per channel (suitable after a convolutional layer).
*/
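
// Illustrative sketch only (not part of the class API): the per-feature
// computation performed in the training forward pass, written out on plain
// doubles. All names below are local to the sketch; n is the batch size.
//
//    double mean = 0., var = 0.;
//    for (size_t i = 0; i < n; i++) mean += x[i] / n;
//    for (size_t i = 0; i < n; i++) var  += (x[i] - mean) * (x[i] - mean) / n;
//    for (size_t i = 0; i < n; i++)
//       y[i] = gamma * (x[i] - mean) / std::sqrt(var + epsilon) + beta;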
template <typename Architecture_t>
class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
public:

   using Scalar_t = typename Architecture_t::Scalar_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Tensor_t = typename Architecture_t::Tensor_t;

   using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
   using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t;

private:
   Tensor_t fDerivatives; ///< First derivatives of the activations of this layer

   int fNormAxis; ///< Normalization axis: -1 (all features), 1 (channel), 2 (height), 3 (width)

   Scalar_t fMomentum; ///< Momentum for the running mean/variance (negative: use a cumulative average)
   Scalar_t fEpsilon;  ///< Small constant added to the variance for numerical stability

   Matrix_t fMu;   ///< Mean of the input batch, computed in the training forward pass
   Matrix_t fVar;  ///< Variance of the input batch
   Matrix_t fIVar; ///< Inverse standard deviation of the batch, cached for the backward pass

   Matrix_t fMu_Training;  ///< Running mean, used at inference time
   Matrix_t fVar_Training; ///< Running variance, used at inference time

   Tensor_t fReshapedData; ///< Input data re-wrapped in the shape/layout expected by the backend

   int fTrainedBatches = 0; ///< Number of batches seen since the last ResetTraining()

   TDescriptors *fDescriptors = nullptr; ///< Architecture-specific descriptors (e.g. cuDNN handles)

public:
   /*! Constructor. `shape` is the output tensor shape used by the enclosing
       network; `axis` selects the normalization axis (see class description). */
   TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                   const std::vector<size_t> &shape, int axis = -1, Scalar_t momentum = -1.,
                   Scalar_t epsilon = 0.0001);

   /*! Copy the batch normalization layer provided as a pointer */
   TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer);

   /*! Copy constructor */
   TBatchNormLayer(const TBatchNormLayer &);

   /*! Destructor */
   ~TBatchNormLayer();

   /*! Compute the activations of the layer for the given input. In training
       mode the batch statistics are used and the running averages are
       updated; in inference mode the running averages are used instead. */
   void Forward(Tensor_t &input, bool inTraining = true);

   /*! Compute the weight and activation gradients from the activation
       gradients of the following layer and the input activations. */
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /*! Reset the batch counter at the start of a new training loop. */
   void ResetTraining() { fTrainedBatches = 0; }

   /*! Printing the layer info. */
   void Print() const;

   /*! Writes the layer parameters (momentum, epsilon, running statistics,
       gamma and beta) into an XML node. */
   virtual void AddWeightsXMLTo(void *parent);

   /*! Read the layer parameters back from an XML node. */
   virtual void ReadWeightsFromXML(void *parent);

   /*! Initialize gamma to 1, beta to 0 and reset the running statistics. */
   virtual void Initialize();

   /* Get the number of batches seen during training */
   const int &GetNTrainedBatches() const { return fTrainedBatches; }
   int &GetNTrainedBatches() { return fTrainedBatches; }

   /* Get the batch mean computed in the forward pass */
   const Matrix_t &GetBatchMean() const { return fMu; }
   Matrix_t &GetBatchMean() { return fMu; }

   /* Get the batch variance */
   const Matrix_t &GetVariance() const { return fVar; }
   Matrix_t &GetVariance() { return fVar; }

   /* Get the inverse standard deviation of the batch */
   const Matrix_t &GetIVariance() const { return fIVar; }
   Matrix_t &GetIVariance() { return fIVar; }

   /* Get the running mean used at inference time */
   const Matrix_t &GetMuVector() const { return fMu_Training; }
   Matrix_t &GetMuVector() { return fMu_Training; }

   /* Get the running variance used at inference time */
   const Matrix_t &GetVarVector() const { return fVar_Training; }
   Matrix_t &GetVarVector() { return fVar_Training; }

   Scalar_t GetMomentum() const { return fMomentum; }

   Scalar_t GetEpsilon() const { return fEpsilon; }

   int GetNormAxis() const { return fNormAxis; }

   const Tensor_t &GetReshapedData() const { return fReshapedData; }
   Tensor_t &GetReshapedData() { return fReshapedData; }

   /* The running mean and variance are treated as extra (non-trainable)
      layer parameters, e.g. when copying a trained network. */
   std::vector<Matrix_t> GetExtraLayerParameters() const
   {
      std::vector<Matrix_t> params(2);
      params[0] = this->GetMuVector();
      params[1] = this->GetVarVector();
      return params;
   }

   void SetExtraLayerParameters(const std::vector<Matrix_t> &params)
   {
      this->GetMuVector() = params[0];
      this->GetVarVector() = params[1];
   }

protected:
   /* Compute the number of normalized features for the given axis: the full
      feature size c*h*w for axis = -1, otherwise the size of the chosen
      dimension. */
   static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
   {
      if (axis == -1)
         return c * h * w;
      else if (axis == 1)
         return c;
      else if (axis == 2)
         return h;
      else if (axis == 3)
         return w;
      return 0;
   }
};
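
// A minimal usage sketch (illustrative, with assumed values): construct a
// batch normalization layer for a flat input of 100 features and a batch of
// 32 events, using an exponential moving average with momentum 0.99 for the
// running statistics. The {batchSize, nFeatures, 1} ordering of the shape
// vector follows the output tensor of the enclosing TDeepNet and is an
// assumption of this sketch; the backend type requires the corresponding
// architecture header (e.g. "TMVA/DNN/Architectures/Cpu.h").
//
//    using Arch_t = TMVA::DNN::TCpu<Double_t>;
//    TBatchNormLayer<Arch_t> bnorm(/*batchSize=*/32, /*inputDepth=*/1,
//                                  /*inputHeight=*/1, /*inputWidth=*/100,
//                                  /*shape=*/{32, 100, 1},
//                                  /*axis=*/-1, /*momentum=*/0.99);
//    bnorm.Initialize();   // gamma = 1, beta = 0, running statistics reset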

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
                                                 size_t inputWidth, const std::vector<size_t> &shape, int axis,
                                                 Scalar_t momentum, Scalar_t epsilon)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // batch size and input shape
                                   inputDepth, inputHeight, inputWidth,            // output shape
                                   2, 1,                                           // two weight matrices (gamma, beta), one row each
                                   CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // weight matrix columns
                                   1, 1, 1,                                        // bias dimensions (unused)
                                   shape[2], shape[0], shape[1],                   // output tensor shape
                                   EInitialization::kZero),
     fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
     fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fReshapedData(1, 1, 1)
{
}

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // copying of the layer-specific members is not implemented
   printf("Error - copy ctor not implemented\n");
}

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(const TBatchNormLayer &layer) : VGeneralLayer<Architecture_t>(layer)
{
   // copying of the layer-specific members is not implemented
   printf("Error - copy ctor not implemented\n");
}

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::~TBatchNormLayer()
{
   // release the architecture-specific descriptors
   if (fDescriptors) {
      Architecture_t::ReleaseBNormDescriptors(fDescriptors);
      delete fDescriptors;
   }
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Initialize() -> void
{
   Matrix_t &gamma = this->GetWeightsAt(0);
   Matrix_t &beta = this->GetWeightsAt(1);
   size_t bndim = gamma.GetNcols();

   // initialize the scale to 1, the shift to 0, and the running statistics
   // to the identity transformation (mean 0, variance 1)
   initialize<Architecture_t>(beta, EInitialization::kZero);
   for (size_t i = 0; i < bndim; ++i) {
      gamma(0, i) = 1.;
      fMu_Training(0, i) = 0;
      fVar_Training(0, i) = 1;
   }

   Matrix_t &dgamma = this->GetWeightGradientsAt(0);
   Matrix_t &dbeta = this->GetWeightGradientsAt(1);
   initialize<Architecture_t>(dgamma, EInitialization::kZero);
   initialize<Architecture_t>(dbeta, EInitialization::kZero);

   fTrainedBatches = 0;

   Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
}
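
// For reference, a sketch of the update rule the backends apply to the
// running statistics when momentum >= 0 (when it is negative a cumulative
// average over the trained batches is used instead; backends may also apply
// an unbiased-variance correction not shown here):
//
//    runningMean = momentum * runningMean + (1. - momentum) * batchMean;
//    runningVar  = momentum * runningVar  + (1. - momentum) * batchVar;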

//______________________________________________________________________________
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Forward(Tensor_t &x, bool inTraining) -> void
{
   Tensor_t x2;
   Tensor_t y2;
   // if the input layout differs from the one expected by the backend,
   // re-wrap the same device buffers with the expected shape and layout
   if (x.GetLayout() != fReshapedData.GetLayout()) {
      x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
   } else {
      x2 = x;
      y2 = this->GetOutput();
   }

   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);
   if (inTraining) {
      // normalize with the batch statistics and update the running averages
      Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
                                                    this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                    this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                                    this->GetMuVector(),
                                                    this->GetVarVector(), this->GetNTrainedBatches(),
                                                    this->GetMomentum(), this->GetEpsilon(),
                                                    descr->HelperDescriptor);
      fTrainedBatches++;
   } else {
      // in inference mode normalize with the running averages of the mean
      // and variance accumulated during training
      Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                     y2, this->GetMuVector(), this->GetVarVector(),
                                                     this->GetEpsilon(), descr->HelperDescriptor);
      fTrainedBatches = 0;
   }
}

//______________________________________________________________________________
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                               const Tensor_t &activations_backward) -> void
{
   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);

   // as in Forward, re-wrap the buffers if the layout differs from the one
   // expected by the backend
   if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
      Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());

      Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
                                             this->GetWeightsAt(0), // gamma
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   } else {
      Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward,
                                             this->GetActivationGradients(),
                                             gradients_backward,
                                             this->GetWeightsAt(0), // gamma
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   }
}
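
// For reference, the standard batch normalization gradients computed by the
// backends have the following per-feature form (a sketch, not the backend
// code; xhat_i = (x_i - mean) * ivar is the normalized input and n the batch
// size):
//
//    dbeta  = sum_i dy_i
//    dgamma = sum_i dy_i * xhat_i
//    dx_i   = gamma * ivar * (dy_i - dbeta / n - xhat_i * dgamma / n)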

//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Print() const
{
   std::cout << " BATCH NORM Layer: \t";
   std::cout << " Input/Output = ( ";
   auto &shape = this->GetOutput().GetShape();
   for (size_t i = 0; i < shape.size(); ++i) {
      if (i > 0)
         std::cout << " , ";
      std::cout << shape[i];
   }
   std::cout << " ) ";
   std::cout << "\t Norm dim =" << std::setw(6) << this->GetWeightsAt(0).GetNcols();
   std::cout << "\t axis = " << fNormAxis << std::endl;
   std::cout << std::endl;
}

//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   // create the layer XML node
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "BatchNormLayer");

   // write the layer attributes
   gTools().AddAttr(layerxml, "Momentum", fMomentum);
   gTools().AddAttr(layerxml, "Epsilon", fEpsilon);

   // write the running statistics accumulated during training
   this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
   this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());

   // write the trainable parameters gamma and beta
   this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
}
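
// Schematically, the XML fragment produced above has the following structure
// (attribute values are examples; the matrix encoding depends on
// WriteMatrixToXML and is elided here):
//
//    <BatchNormLayer Momentum="0.99" Epsilon="0.0001">
//       <Training-mu ... />
//       <Training-variance ... />
//       <Gamma ... />
//       <Beta ... />
//    </BatchNormLayer>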

//______________________________________________________________________________
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   // read the layer attributes
   gTools().ReadAttr(parent, "Momentum", fMomentum);
   gTools().ReadAttr(parent, "Epsilon", fEpsilon);

   // read the running statistics
   this->ReadMatrixXML(parent, "Training-mu", this->GetMuVector());
   this->ReadMatrixXML(parent, "Training-variance", this->GetVarVector());

   // read gamma and beta
   this->ReadMatrixXML(parent, "Gamma", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "Beta", this->GetWeightsAt(1));
}

} // namespace DNN
} // namespace TMVA

#endif // TMVA_DNN_BatchNormLayer