File indexing completed on 2025-01-18 10:10:54
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029 #ifndef TMVA_DNN_RNN_LAYER
0030 #define TMVA_DNN_RNN_LAYER
0031
0032 #include <cmath>
0033 #include <iostream>
0034 #include <vector>
0035 #include <string>
0036
0037 #include "TMatrix.h"
0038 #include "TMVA/DNN/Functions.h"
0039
0040 namespace TMVA
0041 {
0042 namespace DNN
0043 {
0044
0045 namespace RNN {
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055 template<typename Architecture_t>
0056 class TBasicRNNLayer : public VGeneralLayer<Architecture_t>
0057 {
0058
0059 public:
0060
0061 using Tensor_t = typename Architecture_t::Tensor_t;
0062 using Matrix_t = typename Architecture_t::Matrix_t;
0063 using Scalar_t = typename Architecture_t::Scalar_t;
0064
0065 using LayerDescriptor_t = typename Architecture_t::RecurrentDescriptor_t;
0066 using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
0067 using TensorDescriptor_t = typename Architecture_t::TensorDescriptor_t;
0068 using HelperDescriptor_t = typename Architecture_t::DropoutDescriptor_t;
0069
0070 using RNNWorkspace_t = typename Architecture_t::RNNWorkspace_t;
0071 using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;
0072
0073 private:
0074
0075 size_t fTimeSteps;
0076 size_t fStateSize;
0077 bool fRememberState;
0078 bool fReturnSequence = false;
0079
0080 DNN::EActivationFunction fF;
0081
0082 Matrix_t fState;
0083 Matrix_t &fWeightsInput;
0084 Matrix_t &fWeightsState;
0085 Matrix_t &fBiases;
0086
0087 Tensor_t fDerivatives;
0088 Matrix_t &fWeightInputGradients;
0089 Matrix_t &fWeightStateGradients;
0090 Matrix_t &fBiasGradients;
0091
0092 Tensor_t fWeightsTensor;
0093 Tensor_t fWeightGradientsTensor;
0094
0095 typename Architecture_t::ActivationDescriptor_t fActivationDesc;
0096
0097 TDescriptors *fDescriptors = nullptr;
0098 TWorkspace *fWorkspace = nullptr;
0099
0100 Matrix_t fCell;
0101
0102
0103 Tensor_t fX;
0104 Tensor_t fY;
0105 Tensor_t fDx;
0106 Tensor_t fDy;
0107
0108
0109 public:
0110
0111
0112 TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize,
0113 size_t timeSteps, bool rememberState = false, bool returnSequence = false,
0114 DNN::EActivationFunction f = DNN::EActivationFunction::kTanh,
0115 bool training = true, DNN::EInitialization fA = DNN::EInitialization::kZero);
0116
0117
0118 TBasicRNNLayer(const TBasicRNNLayer &);
0119
0120
0121 virtual ~TBasicRNNLayer();
0122
0123
0124
0125 virtual void Initialize();
0126
0127
0128
0129 void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);
0130
0131
0132
0133 void Forward(Tensor_t &input, bool isTraining = true);
0134
0135
0136 void CellForward(const Matrix_t &input, Matrix_t & dF);
0137
0138
0139
0140 void Backward(Tensor_t &gradients_backward,
0141 const Tensor_t &activations_backward);
0142
0143
0144 void Update(const Scalar_t learningRate);
0145
0146
0147
0148 inline Matrix_t & CellBackward(Matrix_t & state_gradients_backward,
0149 const Matrix_t & precStateActivations,
0150 const Matrix_t & input, Matrix_t & input_gradient, Matrix_t &dF);
0151
0152
0153 void Print() const;
0154
0155
0156 virtual void AddWeightsXMLTo(void *parent);
0157
0158
0159 virtual void ReadWeightsFromXML(void *parent);
0160
0161 void InitTensors();
0162
0163
0164
0165
0166
0167
0168 size_t GetTimeSteps() const { return fTimeSteps; }
0169 size_t GetStateSize() const { return fStateSize; }
0170 size_t GetInputSize() const { return this->GetInputWidth(); }
0171 inline bool DoesRememberState() const {return fRememberState;}
0172 inline bool DoesReturnSequence() const { return fReturnSequence; }
0173 inline DNN::EActivationFunction GetActivationFunction() const {return fF;}
0174 Matrix_t & GetState() {return fState;}
0175 const Matrix_t & GetState() const {return fState;}
0176 Matrix_t &GetCell() { return fCell; }
0177 const Matrix_t &GetCell() const { return fCell; }
0178
0179 Matrix_t & GetWeightsInput() {return fWeightsInput;}
0180 const Matrix_t & GetWeightsInput() const {return fWeightsInput;}
0181 Matrix_t & GetWeightsState() {return fWeightsState;}
0182 const Matrix_t & GetWeightsState() const {return fWeightsState;}
0183 Tensor_t & GetDerivatives() {return fDerivatives;}
0184 const Tensor_t & GetDerivatives() const {return fDerivatives;}
0185
0186
0187
0188 Matrix_t & GetBiasesState() {return fBiases;}
0189 const Matrix_t & GetBiasesState() const {return fBiases;}
0190 Matrix_t & GetBiasStateGradients() {return fBiasGradients;}
0191 const Matrix_t & GetBiasStateGradients() const {return fBiasGradients;}
0192 Matrix_t & GetWeightInputGradients() {return fWeightInputGradients;}
0193 const Matrix_t & GetWeightInputGradients() const {return fWeightInputGradients;}
0194 Matrix_t & GetWeightStateGradients() {return fWeightStateGradients;}
0195 const Matrix_t & GetWeightStateGradients() const {return fWeightStateGradients;}
0196
0197 Tensor_t &GetWeightsTensor() { return fWeightsTensor; }
0198 const Tensor_t &GetWeightsTensor() const { return fWeightsTensor; }
0199 Tensor_t &GetWeightGradientsTensor() { return fWeightGradientsTensor; }
0200 const Tensor_t &GetWeightGradientsTensor() const { return fWeightGradientsTensor; }
0201
0202 Tensor_t &GetX() { return fX; }
0203 Tensor_t &GetY() { return fY; }
0204 Tensor_t &GetDX() { return fDx; }
0205 Tensor_t &GetDY() { return fDy; }
0206 };
0207
0208
0209
0210
0211
0212 template <typename Architecture_t>
0213 TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
0214 bool rememberState, bool returnSequence, DNN::EActivationFunction f, bool ,
0215 DNN::EInitialization fA)
0216
0217 : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1 ,
0218 stateSize, 2, {stateSize, stateSize}, {inputSize, stateSize}, 1, {stateSize}, {1},
0219 batchSize, (returnSequence) ? timeSteps : 1, stateSize, fA),
0220 fTimeSteps(timeSteps), fStateSize(stateSize), fRememberState(rememberState), fReturnSequence(returnSequence), fF(f), fState(batchSize, stateSize),
0221 fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)),
0222 fBiases(this->GetBiasesAt(0)), fDerivatives(timeSteps, batchSize, stateSize),
0223 fWeightInputGradients(this->GetWeightGradientsAt(0)), fWeightStateGradients(this->GetWeightGradientsAt(1)),
0224 fBiasGradients(this->GetBiasGradientsAt(0)), fWeightsTensor({0}), fWeightGradientsTensor({0})
0225 {
0226 InitTensors();
0227 }
0228
0229
0230 template <typename Architecture_t>
0231 TBasicRNNLayer<Architecture_t>::TBasicRNNLayer(const TBasicRNNLayer &layer)
0232 : VGeneralLayer<Architecture_t>(layer), fTimeSteps(layer.fTimeSteps), fStateSize(layer.fStateSize),
0233 fRememberState(layer.fRememberState), fReturnSequence(layer.fReturnSequence), fF(layer.GetActivationFunction()),
0234 fState(layer.GetBatchSize(), layer.GetStateSize()),
0235 fWeightsInput(this->GetWeightsAt(0)), fWeightsState(this->GetWeightsAt(1)), fBiases(this->GetBiasesAt(0)),
0236 fDerivatives(layer.GetDerivatives().GetShape()), fWeightInputGradients(this->GetWeightGradientsAt(0)),
0237 fWeightStateGradients(this->GetWeightGradientsAt(1)), fBiasGradients(this->GetBiasGradientsAt(0)),
0238 fWeightsTensor({0}), fWeightGradientsTensor({0})
0239 {
0240
0241 Architecture_t::Copy(fDerivatives, layer.GetDerivatives() );
0242
0243
0244 Architecture_t::Copy(fState, layer.GetState());
0245 InitTensors();
0246 }
0247
0248 template <typename Architecture_t>
0249 TBasicRNNLayer<Architecture_t>::~TBasicRNNLayer()
0250 {
0251 if (fDescriptors) {
0252 Architecture_t::ReleaseRNNDescriptors(fDescriptors);
0253 delete fDescriptors;
0254 }
0255
0256 if (fWorkspace) {
0257 Architecture_t::FreeRNNWorkspace(fWorkspace);
0258 delete fWorkspace;
0259 }
0260 }
0261
0262
0263 template<typename Architecture_t>
0264 void TBasicRNNLayer<Architecture_t>::Initialize()
0265 {
0266
0267
0268
0269
0270
0271 VGeneralLayer<Architecture_t>::Initialize();
0272
0273 Architecture_t::InitializeRNNDescriptors(fDescriptors, this);
0274 Architecture_t::InitializeRNNWorkspace(fWorkspace, fDescriptors, this);
0275 }
0276
0277
0278 template <typename Architecture_t>
0279 void TBasicRNNLayer<Architecture_t>::InitTensors()
0280 {
0281
0282 Architecture_t::InitializeRNNTensors(this);
0283 }
0284
0285 template <typename Architecture_t>
0286 auto TBasicRNNLayer<Architecture_t>::InitState(DNN::EInitialization ) -> void
0287 {
0288 DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
0289
0290 Architecture_t::InitializeActivationDescriptor(fActivationDesc,this->GetActivationFunction());
0291 }
0292
0293
0294 template<typename Architecture_t>
0295 auto TBasicRNNLayer<Architecture_t>::Print() const
0296 -> void
0297 {
0298 std::cout << " RECURRENT Layer: \t ";
0299 std::cout << " (NInput = " << this->GetInputSize();
0300 std::cout << ", NState = " << this->GetStateSize();
0301 std::cout << ", NTime = " << this->GetTimeSteps() << " )";
0302 std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput().GetHSize() << " , " << this->GetOutput().GetWSize() << " )\n";
0303 }
0304
/// Debugging helper: prints `name` followed by the elements of `A`, one matrix
/// row per output line, and a closing separator line, all to std::cout.
///
/// `Architecture_t` cannot be deduced from the argument and must be given
/// explicitly, e.g. `debugMatrix<Architecture_t>(m, "weights")`.
///
/// \param A     matrix to print; must provide GetNrows(), GetNcols() and operator()(i, j)
/// \param name  label printed before the matrix (taken by const reference to avoid
///              copying the string on every call)
template <typename Architecture_t>
auto debugMatrix(const typename Architecture_t::Matrix_t &A, const std::string &name = "matrix") -> void
{
   std::cout << name << "\n";

   // The dimensions of a const matrix do not change while printing: query them once.
   const size_t nRows = static_cast<size_t>(A.GetNrows());
   const size_t nCols = static_cast<size_t>(A.GetNcols());

   for (size_t i = 0; i < nRows; ++i) {
      for (size_t j = 0; j < nCols; ++j) {
         std::cout << A(i, j) << " ";
      }
      std::cout << "\n";
   }
   std::cout << "********\n";
}
0318
0319
0320
0321 template <typename Architecture_t>
0322 void TBasicRNNLayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining )
0323 {
0324
0325
0326
0327 if (Architecture_t::IsCudnn()) {
0328
0329 Tensor_t &x = this->fX;
0330 Tensor_t &y = this->fY;
0331
0332 Architecture_t::Rearrange(x, input);
0333
0334
0335
0336
0337 const auto & weights = this->GetWeightsTensor();
0338
0339
0340
0341
0342
0343 auto &hx = this->GetState();
0344 auto &cx = this->GetCell();
0345
0346 auto &hy = this->GetState();
0347 auto &cy = this->GetCell();
0348
0349 auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
0350 auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
0351
0352
0353
0354 Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);
0355
0356 if (fReturnSequence) {
0357 Architecture_t::Rearrange(this->GetOutput(), y);
0358 }
0359 else {
0360
0361 Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
0362 Architecture_t::Copy(this->GetOutput(), tmp);
0363 }
0364 return;
0365 }
0366
0367
0368
0369
0370
0371
0372
0373 Tensor_t arrInput (fTimeSteps, this->GetBatchSize(), this->GetInputWidth() );
0374
0375 Architecture_t::Rearrange(arrInput, input);
0376 Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize);
0377
0378
0379 if (!this->fRememberState) InitState(DNN::EInitialization::kZero);
0380
0381 for (size_t t = 0; t < fTimeSteps; ++t) {
0382 Matrix_t arrInput_m = arrInput.At(t).GetMatrix();
0383 Matrix_t df_m = fDerivatives.At(t).GetMatrix();
0384 CellForward(arrInput_m, df_m );
0385 Matrix_t arrOutput_m = arrOutput.At(t).GetMatrix();
0386 Architecture_t::Copy(arrOutput_m, fState);
0387 }
0388
0389 if (fReturnSequence)
0390 Architecture_t::Rearrange(this->GetOutput(), arrOutput);
0391 else {
0392
0393
0394 Tensor_t tmp = arrOutput.At(fTimeSteps - 1);
0395
0396
0397 tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
0398 assert(tmp.GetSize() == this->GetOutput().GetSize());
0399 assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);
0400 Architecture_t::Rearrange(this->GetOutput(), tmp);
0401
0402 fY = arrOutput;
0403 }
0404 }
0405
0406
0407 template <typename Architecture_t>
0408 auto inline TBasicRNNLayer<Architecture_t>::CellForward(const Matrix_t &input, Matrix_t &dF)
0409 -> void
0410 {
0411
0412 const DNN::EActivationFunction fAF = this->GetActivationFunction();
0413 Matrix_t tmpState(fState.GetNrows(), fState.GetNcols());
0414 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsState);
0415 Architecture_t::MultiplyTranspose(fState, input, fWeightsInput);
0416 Architecture_t::ScaleAdd(fState, tmpState);
0417 Architecture_t::AddRowWise(fState, fBiases);
0418 Tensor_t inputActivFunc(dF);
0419 Tensor_t tState(fState);
0420
0421
0422
0423
0424 Architecture_t::Copy(inputActivFunc, tState);
0425 Architecture_t::ActivationFunctionForward(tState, fAF, fActivationDesc);
0426
0427 }
0428
0429
0430 template <typename Architecture_t>
0431 auto inline TBasicRNNLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
0432 const Tensor_t &activations_backward) -> void
0433
0434
0435 {
0436
0437 if (Architecture_t::IsCudnn() ) {
0438
0439 Tensor_t &x = this->fX;
0440 Tensor_t &y = this->fY;
0441 Tensor_t &dx = this->fDx;
0442 Tensor_t &dy = this->fDy;
0443
0444
0445 assert(activations_backward.GetStrides()[1] == this->GetInputSize() );
0446
0447 Architecture_t::Rearrange(x, activations_backward);
0448
0449 if (!fReturnSequence) {
0450
0451
0452 Architecture_t::InitializeZero(dy);
0453
0454
0455 Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
0456
0457
0458 Architecture_t::Copy(tmp2, this->GetActivationGradients());
0459 }
0460 else {
0461 Architecture_t::Rearrange(y, this->GetOutput());
0462 Architecture_t::Rearrange(dy, this->GetActivationGradients());
0463 }
0464
0465
0466
0467
0468
0469 auto &weights = this->GetWeightsTensor();
0470 auto &weightGradients = this->GetWeightGradientsTensor();
0471
0472
0473 Architecture_t::InitializeZero(weightGradients);
0474
0475
0476
0477 auto &hx = this->GetState();
0478 auto &cx = this->GetCell();
0479
0480 auto &dhy = hx;
0481 auto &dcy = cx;
0482 auto &dhx = hx;
0483 auto &dcx = cx;
0484
0485
0486 auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
0487 auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
0488
0489 Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
0490
0491 if (gradients_backward.GetSize() != 0)
0492 Architecture_t::Rearrange(gradients_backward, dx);
0493
0494 return;
0495 }
0496
0497
0498
0499
0500
0501
0502
0503
0504 bool dummy = false;
0505 if (gradients_backward.GetSize() == 0) {
0506 dummy = true;
0507 }
0508 Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
0509
0510
0511 if (!dummy) {
0512
0513
0514 }
0515 Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
0516
0517 Architecture_t::Rearrange(arr_activations_backward, activations_backward);
0518
0519 Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize);
0520 DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero);
0521
0522 Matrix_t initState(this->GetBatchSize(), fStateSize);
0523 DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero);
0524
0525 Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);
0526 Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);
0527
0528 if (fReturnSequence) {
0529 Architecture_t::Rearrange(arr_output, this->GetOutput());
0530 Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
0531 } else {
0532
0533 arr_output = fY;
0534
0535 Architecture_t::InitializeZero(arr_actgradients);
0536
0537 Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
0538 assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
0539 assert(tmp_grad.GetShape()[0] ==
0540 this->GetActivationGradients().GetShape()[2]);
0541
0542 Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
0543 }
0544
0545
0546 fWeightInputGradients.Zero();
0547 fWeightStateGradients.Zero();
0548 fBiasGradients.Zero();
0549
0550 for (size_t t = fTimeSteps; t > 0; t--) {
0551
0552 Matrix_t actgrad_m = arr_actgradients.At(t - 1).GetMatrix();
0553 Architecture_t::ScaleAdd(state_gradients_backward, actgrad_m);
0554
0555 Matrix_t actbw_m = arr_activations_backward.At(t - 1).GetMatrix();
0556 Matrix_t gradbw_m = arr_gradients_backward.At(t - 1).GetMatrix();
0557
0558
0559 Tensor_t df = fDerivatives.At(t-1);
0560 Tensor_t dy = Tensor_t(state_gradients_backward);
0561
0562 Tensor_t y = arr_output.At(t-1);
0563 Architecture_t::ActivationFunctionBackward(df, y,
0564 dy, df,
0565 this->GetActivationFunction(), fActivationDesc);
0566
0567 Matrix_t df_m = df.GetMatrix();
0568
0569
0570 if (t > 1) {
0571 Matrix_t precStateActivations = arr_output.At(t - 2).GetMatrix();
0572 CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
0573
0574 } else {
0575 const Matrix_t & precStateActivations = initState;
0576 CellBackward(state_gradients_backward, precStateActivations, actbw_m, gradbw_m, df_m);
0577
0578 }
0579 }
0580 if (!dummy) {
0581 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
0582 }
0583 }
0584
0585
0586 template <typename Architecture_t>
0587 auto inline TBasicRNNLayer<Architecture_t>::CellBackward(Matrix_t & state_gradients_backward,
0588 const Matrix_t & precStateActivations,
0589 const Matrix_t & input, Matrix_t & input_gradient, Matrix_t &dF)
0590 -> Matrix_t &
0591 {
0592 return Architecture_t::RecurrentLayerBackward(state_gradients_backward, fWeightInputGradients, fWeightStateGradients,
0593 fBiasGradients, dF, precStateActivations, fWeightsInput,
0594 fWeightsState, input, input_gradient);
0595 }
0596
0597
0598 template <typename Architecture_t>
0599 void TBasicRNNLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
0600 {
0601 auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "RNNLayer");
0602
0603
0604 gTools().xmlengine().NewAttr(layerxml, nullptr, "StateSize", gTools().StringFromInt(this->GetStateSize()));
0605 gTools().xmlengine().NewAttr(layerxml, nullptr, "InputSize", gTools().StringFromInt(this->GetInputSize()));
0606 gTools().xmlengine().NewAttr(layerxml, nullptr, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
0607 gTools().xmlengine().NewAttr(layerxml, nullptr, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
0608 gTools().xmlengine().NewAttr(layerxml, nullptr, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));
0609
0610
0611 this->WriteMatrixToXML(layerxml, "InputWeights", this -> GetWeightsAt(0));
0612 this->WriteMatrixToXML(layerxml, "StateWeights", this -> GetWeightsAt(1));
0613 this->WriteMatrixToXML(layerxml, "Biases", this -> GetBiasesAt(0));
0614
0615
0616 }
0617
0618
0619 template <typename Architecture_t>
0620 void TBasicRNNLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
0621 {
0622
0623 this->ReadMatrixXML(parent,"InputWeights", this -> GetWeightsAt(0));
0624 this->ReadMatrixXML(parent,"StateWeights", this -> GetWeightsAt(1));
0625 this->ReadMatrixXML(parent,"Biases", this -> GetBiasesAt(0));
0626
0627 }
0628
0629 }
0630 }
0631 }
0632
0633 #endif