#ifndef TMVA_DNN_BatchNormLayer
#define TMVA_DNN_BatchNormLayer

#include "TMVA/DNN/GeneralLayer.h"
#include "TMVA/DNN/Functions.h"

#include "TMVA/DNN/Architectures/Reference.h"

#include "TMVA/DNN/CNN/ContextHandles.h"

#include <iostream>
#include <iomanip>
#include <vector>

namespace TMVA {
namespace DNN {

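/** \class TBatchNormLayer

    Layer implementing batch normalization.

    During training, the inputs of each batch are normalized to zero mean
    and unit variance and then scaled and shifted by two trainable
    parameters per normalized element, a scale (gamma) and an offset (beta):

        y = gamma * (x - mu_batch) / sqrt(var_batch + epsilon) + beta

    A running mean and variance are also accumulated during training and
    are used in place of the batch statistics when the layer is evaluated
    in inference mode. For a momentum value in [0,1] the running statistics
    are exponential moving averages,

        running_mean = momentum * running_mean + (1 - momentum) * batch_mean

    while a negative momentum (the default) selects the cumulative average
    over all batches seen so far. Within TMVA the layer is typically added
    to a network through the MethodDL layout string using the `BNORM`
    keyword.
*/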
template <typename Architecture_t>
class TBatchNormLayer : public VGeneralLayer<Architecture_t> {
public:

   using Scalar_t = typename Architecture_t::Scalar_t;
   using Matrix_t = typename Architecture_t::Matrix_t;
   using Tensor_t = typename Architecture_t::Tensor_t;

   using HelperDescriptor_t = typename Architecture_t::TensorDescriptor_t;
   using BNormDescriptors_t = typename Architecture_t::BNormDescriptors_t;

private:
   Tensor_t fDerivatives; ///< First derivatives of the activations of this layer

   int fNormAxis; ///< Normalization axis: mean and variance are computed for each element of this axis

   Scalar_t fMomentum; ///< Momentum of the running mean and variance (negative value: cumulative average)
   Scalar_t fEpsilon;  ///< Regularization constant added to the variance

   Matrix_t fMu;   ///< Mean of the current batch
   Matrix_t fVar;  ///< Variance of the current batch
   Matrix_t fIVar; ///< Inverse standard deviation of the current batch

   Matrix_t fMu_Training;  ///< Running mean, used during inference
   Matrix_t fVar_Training; ///< Running variance, used during inference

   Tensor_t fReshapedData; ///< Cached tensor used to present the data in the layout expected by the backend

   int fTrainedBatches = 0; ///< Counter of batches processed so far during training

   TDescriptors *fDescriptors = nullptr; ///< Architecture-specific batch-normalization descriptors

public:
   /// Constructor
   TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight, size_t inputWidth,
                   const std::vector<size_t> &shape, int axis = -1, Scalar_t momentum = -1.,
                   Scalar_t epsilon = 0.0001);

   /// Copy the batch-normalization layer provided as a pointer
   TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer);

   /// Copy constructor
   TBatchNormLayer(const TBatchNormLayer &);

   /// Destructor
   ~TBatchNormLayer();

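   /// Compute the output of the layer for the given input. During training
   /// the batch mean and variance are used for the normalization and the
   /// running averages are updated; during inference the stored running
   /// mean and variance are used instead.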
   void Forward(Tensor_t &input, bool inTraining = true);

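   /// Compute the gradients of the loss with respect to the inputs
   /// (gradients_backward) and to the scale and offset parameters.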
   void Backward(Tensor_t &gradients_backward, const Tensor_t &activations_backward);

   /// Reset the counter of trained batches
   void ResetTraining() { fTrainedBatches = 0; }

   /// Print the layer information
   void Print() const;

   /// Write the layer parameters and running statistics to XML
   virtual void AddWeightsXMLTo(void *parent);

   /// Read the layer parameters and running statistics from XML
   virtual void ReadWeightsFromXML(void *parent);

   /// Initialize the layer parameters and the running statistics
   virtual void Initialize();

   /// Get the number of batches processed so far during training
   const int &GetNTrainedBatches() const { return fTrainedBatches; }
   int &GetNTrainedBatches() { return fTrainedBatches; }

   /// Get the mean of the current batch
   const Matrix_t &GetBatchMean() const { return fMu; }
   Matrix_t &GetBatchMean() { return fMu; }

   /// Get the variance of the current batch
   const Matrix_t &GetVariance() const { return fVar; }
   Matrix_t &GetVariance() { return fVar; }

   /// Get the inverse standard deviation of the current batch
   const Matrix_t &GetIVariance() const { return fIVar; }
   Matrix_t &GetIVariance() { return fIVar; }

   /// Get the running mean used during inference
   const Matrix_t &GetMuVector() const { return fMu_Training; }
   Matrix_t &GetMuVector() { return fMu_Training; }

   /// Get the running variance used during inference
   const Matrix_t &GetVarVector() const { return fVar_Training; }
   Matrix_t &GetVarVector() { return fVar_Training; }

   /// Get the momentum of the running averages
   Scalar_t GetMomentum() const { return fMomentum; }

   /// Get the epsilon regularization constant
   Scalar_t GetEpsilon() const { return fEpsilon; }

   /// Get the normalization axis
   Scalar_t GetNormAxis() const { return fNormAxis; }

   /// Get the cached reshaped data tensor
   const Tensor_t &GetReshapedData() const { return fReshapedData; }
   Tensor_t &GetReshapedData() { return fReshapedData; }

   /// Return the running mean and variance as extra (non-trainable) parameters
   std::vector<Matrix_t> GetExtraLayerParameters() const
   {
      std::vector<Matrix_t> params(2);
      params[0] = this->GetMuVector();
      params[1] = this->GetVarVector();
      return params;
   }

   /// Set the running mean and variance from extra (non-trainable) parameters
   void SetExtraLayerParameters(const std::vector<Matrix_t> &params)
   {
      this->GetMuVector() = params[0];
      this->GetVarVector() = params[1];
   }

protected:
   /// Size of the normalized dimension: all input features for axis = -1,
   /// otherwise the extent of the chosen axis (1: depth, 2: height, 3: width)
   static size_t CalculateNormDim(int axis, size_t c, size_t h, size_t w)
   {
      if (axis == -1)
         return c * h * w;
      else if (axis == 1)
         return c;
      else if (axis == 2)
         return h;
      else if (axis == 3)
         return w;
      return 0;
   }
};

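
//______________________________________________________________________________
/// Constructor. The layer carries two trainable parameter rows, the scale
/// (gamma) and the offset (beta), each with CalculateNormDim(axis, c, h, w)
/// columns, and produces an output tensor with the same shape as the input.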
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(size_t batchSize, size_t inputDepth, size_t inputHeight,
                                                 size_t inputWidth, const std::vector<size_t> &shape, int axis,
                                                 Scalar_t momentum, Scalar_t epsilon)
   : VGeneralLayer<Architecture_t>(batchSize, inputDepth, inputHeight, inputWidth, // input shape
                                   inputDepth, inputHeight, inputWidth,            // output shape (same as input)
                                   2, 1,                                           // two parameter matrices of one row
                                   CalculateNormDim(axis, inputDepth, inputHeight, inputWidth), // parameter columns
                                   1, 1, 1,                                        // bias dimensions (unused)
                                   shape[2], shape[0], shape[1],                   // output tensor shape
                                   EInitialization::kZero),
     fNormAxis(axis), fMomentum(momentum), fEpsilon(epsilon),
     fMu(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fIVar(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fMu_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fVar_Training(1, VGeneralLayer<Architecture_t>::GetWeightsAt(0).GetNcols()),
     fReshapedData(1, 1, 1)
{
}

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(TBatchNormLayer<Architecture_t> *layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // copying of the layer-specific data is not yet supported
   printf("Error - copy ctor not implemented\n");
}

//______________________________________________________________________________
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::TBatchNormLayer(const TBatchNormLayer &layer)
   : VGeneralLayer<Architecture_t>(layer)
{
   // copying of the layer-specific data is not yet supported
   printf("Error - copy ctor not implemented\n");
}

//______________________________________________________________________________
/// Destructor: release the architecture-specific descriptors
template <typename Architecture_t>
TBatchNormLayer<Architecture_t>::~TBatchNormLayer()
{
   if (fDescriptors) {
      Architecture_t::ReleaseBNormDescriptors(fDescriptors);
      delete fDescriptors;
   }
}

//______________________________________________________________________________
/// Initialize the layer: gamma is set to 1 and beta to 0, the running mean
/// and variance are reset to 0 and 1 respectively, the parameter gradients
/// are zeroed and the backend descriptors are created.
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Initialize() -> void
{
   Matrix_t &gamma = this->GetWeightsAt(0);
   Matrix_t &beta = this->GetWeightsAt(1);
   size_t bndim = gamma.GetNcols();

   initialize<Architecture_t>(beta, EInitialization::kZero);
   for (size_t i = 0; i < bndim; ++i) {
      gamma(0, i) = 1.;
      fMu_Training(0, i) = 0;
      fVar_Training(0, i) = 1;
   }

   Matrix_t &dgamma = this->GetWeightGradientsAt(0);
   Matrix_t &dbeta = this->GetWeightGradientsAt(1);
   initialize<Architecture_t>(dgamma, EInitialization::kZero);
   initialize<Architecture_t>(dbeta, EInitialization::kZero);

   fTrainedBatches = 0;

   Architecture_t::InitializeBNormDescriptors(fDescriptors, this);
}
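
//______________________________________________________________________________
/// Forward pass. If the memory layout of the input differs from the layout
/// of the cached reshaped tensor, the input and output buffers are first
/// re-wrapped in tensors with that layout, without copying the data.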
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Forward(Tensor_t &x, bool inTraining) -> void
{
   Tensor_t x2;
   Tensor_t y2;
   if (x.GetLayout() != fReshapedData.GetLayout()) {
      // wrap the input and output buffers in tensors with the expected layout (no copy)
      x2 = Tensor_t(x.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      y2 = Tensor_t(this->GetOutput().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
   } else {
      x2 = x;
      y2 = this->GetOutput();
   }

   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);
   if (inTraining) {
      // normalize with the batch statistics and update the running averages
      Architecture_t::BatchNormLayerForwardTraining(fNormAxis, x2, y2,
                                                    this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                    this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                                    this->GetMuVector(),
                                                    this->GetVarVector(), this->GetNTrainedBatches(),
                                                    this->GetMomentum(), this->GetEpsilon(),
                                                    descr->HelperDescriptor);
      fTrainedBatches++;
   } else {
      // normalize with the running mean and variance computed during training
      Architecture_t::BatchNormLayerForwardInference(fNormAxis, x2, this->GetWeightsAt(0), this->GetWeightsAt(1),
                                                     y2, this->GetMuVector(), this->GetVarVector(),
                                                     this->GetEpsilon(), descr->HelperDescriptor);
      fTrainedBatches = 0;
   }
}
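
//______________________________________________________________________________
/// Backward pass: compute the gradients with respect to the layer inputs and
/// the gradients of the scale and offset parameters. As in the forward pass,
/// the buffers are re-wrapped first if their layout differs from the expected one.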
template <typename Architecture_t>
auto TBatchNormLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
                                               const Tensor_t &activations_backward) -> void
{
   auto descr = static_cast<BNormDescriptors_t *>(fDescriptors);

   if (activations_backward.GetLayout() != fReshapedData.GetLayout()) {
      // wrap the buffers in tensors with the expected layout (no copy)
      Tensor_t x = Tensor_t(activations_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dx = Tensor_t(gradients_backward.GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());
      Tensor_t dy = Tensor_t(this->GetActivationGradients().GetDeviceBuffer(), fReshapedData.GetShape(), fReshapedData.GetLayout());

      Architecture_t::BatchNormLayerBackward(fNormAxis, x, dy, dx,
                                             this->GetWeightsAt(0),
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   } else {
      Architecture_t::BatchNormLayerBackward(fNormAxis, activations_backward,
                                             this->GetActivationGradients(),
                                             gradients_backward,
                                             this->GetWeightsAt(0),
                                             this->GetWeightGradientsAt(0), this->GetWeightGradientsAt(1),
                                             this->GetBatchMean(), this->GetVariance(), this->GetIVariance(),
                                             this->GetEpsilon(), descr->HelperDescriptor);
   }
}
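
//______________________________________________________________________________
/// Print the layer shape, the size of the normalized dimension and the
/// normalization axis.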
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::Print() const
{
   std::cout << " BATCH NORM Layer: \t";
   std::cout << " Input/Output = ( ";
   auto &shape = this->GetOutput().GetShape();
   for (size_t i = 0; i < shape.size(); ++i) {
      if (i > 0)
         std::cout << " , ";
      std::cout << shape[i];
   }
   std::cout << " ) ";
   std::cout << "\t Norm dim = " << std::setw(6) << this->GetWeightsAt(0).GetNcols();
   std::cout << "\t axis = " << fNormAxis << std::endl;
   std::cout << std::endl;
}
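
//______________________________________________________________________________
/// Write the layer configuration (momentum, epsilon), the running mean and
/// variance and the trained gamma and beta parameters to XML.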
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "BatchNormLayer");

   // write the layer configuration
   gTools().AddAttr(layerxml, "Momentum", fMomentum);
   gTools().AddAttr(layerxml, "Epsilon", fEpsilon);

   // write the running mean and variance computed during training
   this->WriteMatrixToXML(layerxml, "Training-mu", this->GetMuVector());
   this->WriteMatrixToXML(layerxml, "Training-variance", this->GetVarVector());

   // write the trained scale and offset parameters
   this->WriteMatrixToXML(layerxml, "Gamma", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "Beta", this->GetWeightsAt(1));
}
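
//______________________________________________________________________________
/// Read the layer configuration, the running statistics and the gamma and
/// beta parameters back from XML.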
template <typename Architecture_t>
void TBatchNormLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
{
   // read the layer configuration
   gTools().ReadAttr(parent, "Momentum", fMomentum);
   gTools().ReadAttr(parent, "Epsilon", fEpsilon);

   // read the running mean and variance
   this->ReadMatrixXML(parent, "Training-mu", this->GetMuVector());
   this->ReadMatrixXML(parent, "Training-variance", this->GetVarVector());

   // read the scale and offset parameters
   this->ReadMatrixXML(parent, "Gamma", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "Beta", this->GetWeightsAt(1));
}

} // namespace DNN
} // namespace TMVA

#endif