File indexing completed on 2025-01-18 10:10:54
0001
0002
0003
0004
0005
0006
0007
0008
0009
0010
0011
0012
0013
0014
0015
0016
0017
0018
0019
0020
0021
0022
0023
0024
0025
0026
0027
0028
0029
0030 #ifndef TMVA_DNN_GRU_LAYER
0031 #define TMVA_DNN_GRU_LAYER
0032
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

#include "TMatrix.h"
#include "TMVA/DNN/Functions.h"
0039
0040 namespace TMVA
0041 {
0042 namespace DNN
0043 {
0044 namespace RNN
0045 {
0046
0047
0048
0049
0050
0051
0052
0053
0054
0055 template<typename Architecture_t>
0056 class TBasicGRULayer : public VGeneralLayer<Architecture_t>
0057 {
0058
0059 public:
0060
0061 using Matrix_t = typename Architecture_t::Matrix_t;
0062 using Scalar_t = typename Architecture_t::Scalar_t;
0063 using Tensor_t = typename Architecture_t::Tensor_t;
0064
0065 using LayerDescriptor_t = typename Architecture_t::RecurrentDescriptor_t;
0066 using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
0067 using TensorDescriptor_t = typename Architecture_t::TensorDescriptor_t;
0068 using HelperDescriptor_t = typename Architecture_t::DropoutDescriptor_t;
0069
0070 using RNNWorkspace_t = typename Architecture_t::RNNWorkspace_t;
0071 using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;
0072
0073 private:
0074
0075 size_t fStateSize;
0076 size_t fTimeSteps;
0077
0078 bool fRememberState;
0079 bool fReturnSequence = false;
0080 bool fResetGateAfter = false;
0081
0082 DNN::EActivationFunction fF1;
0083 DNN::EActivationFunction fF2;
0084
0085 Matrix_t fResetValue;
0086 Matrix_t fUpdateValue;
0087 Matrix_t fCandidateValue;
0088 Matrix_t fState;
0089
0090
0091 Matrix_t &fWeightsResetGate;
0092 Matrix_t &fWeightsResetGateState;
0093 Matrix_t &fResetGateBias;
0094
0095 Matrix_t &fWeightsUpdateGate;
0096 Matrix_t &fWeightsUpdateGateState;
0097 Matrix_t &fUpdateGateBias;
0098
0099 Matrix_t &fWeightsCandidate;
0100 Matrix_t &fWeightsCandidateState;
0101 Matrix_t &fCandidateBias;
0102
0103
0104 std::vector<Matrix_t> reset_gate_value;
0105 std::vector<Matrix_t> update_gate_value;
0106 std::vector<Matrix_t> candidate_gate_value;
0107
0108 std::vector<Matrix_t> fDerivativesReset;
0109 std::vector<Matrix_t> fDerivativesUpdate;
0110 std::vector<Matrix_t> fDerivativesCandidate;
0111
0112 Matrix_t &fWeightsResetGradients;
0113 Matrix_t &fWeightsResetStateGradients;
0114 Matrix_t &fResetBiasGradients;
0115 Matrix_t &fWeightsUpdateGradients;
0116 Matrix_t &fWeightsUpdateStateGradients;
0117 Matrix_t &fUpdateBiasGradients;
0118 Matrix_t &fWeightsCandidateGradients;
0119 Matrix_t &fWeightsCandidateStateGradients;
0120 Matrix_t &fCandidateBiasGradients;
0121
0122 Matrix_t fCell;
0123
0124
0125 Tensor_t fWeightsTensor;
0126 Tensor_t fWeightGradientsTensor;
0127
0128
0129 Tensor_t fX;
0130 Tensor_t fY;
0131 Tensor_t fDx;
0132 Tensor_t fDy;
0133
0134 TDescriptors *fDescriptors = nullptr;
0135 TWorkspace *fWorkspace = nullptr;
0136
0137 public:
0138
0139
0140 TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize,
0141 size_t timeSteps, bool rememberState = false, bool returnSequence = false,
0142 bool resetGateAfter = false,
0143 DNN::EActivationFunction f1 = DNN::EActivationFunction::kSigmoid,
0144 DNN::EActivationFunction f2 = DNN::EActivationFunction::kTanh,
0145 bool training = true, DNN::EInitialization fA = DNN::EInitialization::kZero);
0146
0147
0148 TBasicGRULayer(const TBasicGRULayer &);
0149
0150
0151
0152 virtual void Initialize();
0153
0154
0155 void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);
0156
0157
0158
0159 void Forward(Tensor_t &input, bool isTraining = true);
0160
0161
0162 void CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues);
0163
0164
0165
0166 void Backward(Tensor_t &gradients_backward,
0167 const Tensor_t &activations_backward);
0168
0169
0170 void Update(const Scalar_t learningRate);
0171
0172
0173
0174 Matrix_t & CellBackward(Matrix_t & state_gradients_backward,
0175 const Matrix_t & precStateActivations,
0176 const Matrix_t & reset_gate, const Matrix_t & update_gate,
0177 const Matrix_t & candidate_gate,
0178 const Matrix_t & input, Matrix_t & input_gradient,
0179 Matrix_t &dr, Matrix_t &du, Matrix_t &dc);
0180
0181
0182 void ResetGate(const Matrix_t &input, Matrix_t &di);
0183
0184
0185 void UpdateGate(const Matrix_t &input, Matrix_t &df);
0186
0187
0188 void CandidateValue(const Matrix_t &input, Matrix_t &dc);
0189
0190
0191 void Print() const;
0192
0193
0194 void AddWeightsXMLTo(void *parent);
0195
0196
0197 void ReadWeightsFromXML(void *parent);
0198
0199
0200 size_t GetInputSize() const { return this->GetInputWidth(); }
0201 size_t GetTimeSteps() const { return fTimeSteps; }
0202 size_t GetStateSize() const { return fStateSize; }
0203
0204 inline bool DoesRememberState() const { return fRememberState; }
0205 inline bool DoesReturnSequence() const { return fReturnSequence; }
0206
0207 inline DNN::EActivationFunction GetActivationFunctionF1() const { return fF1; }
0208 inline DNN::EActivationFunction GetActivationFunctionF2() const { return fF2; }
0209
0210 const Matrix_t & GetResetGateValue() const { return fResetValue; }
0211 Matrix_t & GetResetGateValue() { return fResetValue; }
0212 const Matrix_t & GetCandidateValue() const { return fCandidateValue; }
0213 Matrix_t & GetCandidateValue() { return fCandidateValue; }
0214 const Matrix_t & GetUpdateGateValue() const { return fUpdateValue; }
0215 Matrix_t & GetUpdateGateValue() { return fUpdateValue; }
0216
0217 const Matrix_t & GetState() const { return fState; }
0218 Matrix_t & GetState() { return fState; }
0219 const Matrix_t &GetCell() const { return fCell; }
0220 Matrix_t & GetCell() { return fCell; }
0221
0222 const Matrix_t & GetWeightsResetGate() const { return fWeightsResetGate; }
0223 Matrix_t & GetWeightsResetGate() { return fWeightsResetGate; }
0224 const Matrix_t & GetWeightsCandidate() const { return fWeightsCandidate; }
0225 Matrix_t & GetWeightsCandidate() { return fWeightsCandidate; }
0226 const Matrix_t & GetWeightsUpdateGate() const { return fWeightsUpdateGate; }
0227 Matrix_t & GetWeightsUpdateGate() { return fWeightsUpdateGate; }
0228
0229 const Matrix_t & GetWeightsResetGateState() const { return fWeightsResetGateState; }
0230 Matrix_t & GetWeightsResetGateState() { return fWeightsResetGateState; }
0231 const Matrix_t & GetWeightsUpdateGateState() const { return fWeightsUpdateGateState; }
0232 Matrix_t & GetWeightsUpdateGateState() { return fWeightsUpdateGateState; }
0233 const Matrix_t & GetWeightsCandidateState() const { return fWeightsCandidateState; }
0234 Matrix_t & GetWeightsCandidateState() { return fWeightsCandidateState; }
0235
0236 const std::vector<Matrix_t> & GetDerivativesReset() const { return fDerivativesReset; }
0237 std::vector<Matrix_t> & GetDerivativesReset() { return fDerivativesReset; }
0238 const Matrix_t & GetResetDerivativesAt(size_t i) const { return fDerivativesReset[i]; }
0239 Matrix_t & GetResetDerivativesAt(size_t i) { return fDerivativesReset[i]; }
0240 const std::vector<Matrix_t> & GetDerivativesUpdate() const { return fDerivativesUpdate; }
0241 std::vector<Matrix_t> & GetDerivativesUpdate() { return fDerivativesUpdate; }
0242 const Matrix_t & GetUpdateDerivativesAt(size_t i) const { return fDerivativesUpdate[i]; }
0243 Matrix_t & GetUpdateDerivativesAt(size_t i) { return fDerivativesUpdate[i]; }
0244 const std::vector<Matrix_t> & GetDerivativesCandidate() const { return fDerivativesCandidate; }
0245 std::vector<Matrix_t> & GetDerivativesCandidate() { return fDerivativesCandidate; }
0246 const Matrix_t & GetCandidateDerivativesAt(size_t i) const { return fDerivativesCandidate[i]; }
0247 Matrix_t & GetCandidateDerivativesAt(size_t i) { return fDerivativesCandidate[i]; }
0248
0249 const std::vector<Matrix_t> & GetResetGateTensor() const { return reset_gate_value; }
0250 std::vector<Matrix_t> & GetResetGateTensor() { return reset_gate_value; }
0251 const Matrix_t & GetResetGateTensorAt(size_t i) const { return reset_gate_value[i]; }
0252 Matrix_t & GetResetGateTensorAt(size_t i) { return reset_gate_value[i]; }
0253 const std::vector<Matrix_t> & GetUpdateGateTensor() const { return update_gate_value; }
0254 std::vector<Matrix_t> & GetUpdateGateTensor() { return update_gate_value; }
0255 const Matrix_t & GetUpdateGateTensorAt(size_t i) const { return update_gate_value[i]; }
0256 Matrix_t & GetUpdateGateTensorAt(size_t i) { return update_gate_value[i]; }
0257 const std::vector<Matrix_t> & GetCandidateGateTensor() const { return candidate_gate_value; }
0258 std::vector<Matrix_t> & GetCandidateGateTensor() { return candidate_gate_value; }
0259 const Matrix_t & GetCandidateGateTensorAt(size_t i) const { return candidate_gate_value[i]; }
0260 Matrix_t & GetCandidateGateTensorAt(size_t i) { return candidate_gate_value[i]; }
0261
0262
0263
0264 const Matrix_t & GetResetGateBias() const { return fResetGateBias; }
0265 Matrix_t & GetResetGateBias() { return fResetGateBias; }
0266 const Matrix_t & GetUpdateGateBias() const { return fUpdateGateBias; }
0267 Matrix_t & GetUpdateGateBias() { return fUpdateGateBias; }
0268 const Matrix_t & GetCandidateBias() const { return fCandidateBias; }
0269 Matrix_t & GetCandidateBias() { return fCandidateBias; }
0270
0271 const Matrix_t & GetWeightsResetGradients() const { return fWeightsResetGradients; }
0272 Matrix_t & GetWeightsResetGradients() { return fWeightsResetGradients; }
0273 const Matrix_t & GetWeightsResetStateGradients() const { return fWeightsResetStateGradients; }
0274 Matrix_t & GetWeightsResetStateGradients() { return fWeightsResetStateGradients; }
0275 const Matrix_t & GetResetBiasGradients() const { return fResetBiasGradients; }
0276 Matrix_t & GetResetBiasGradients() { return fResetBiasGradients; }
0277 const Matrix_t & GetWeightsUpdateGradients() const { return fWeightsUpdateGradients; }
0278 Matrix_t & GetWeightsUpdateGradients() { return fWeightsUpdateGradients; }
0279 const Matrix_t & GetWeigthsUpdateStateGradients() const { return fWeightsUpdateStateGradients; }
0280 Matrix_t & GetWeightsUpdateStateGradients() { return fWeightsUpdateStateGradients; }
0281 const Matrix_t & GetUpdateBiasGradients() const { return fUpdateBiasGradients; }
0282 Matrix_t & GetUpdateBiasGradients() { return fUpdateBiasGradients; }
0283 const Matrix_t & GetWeightsCandidateGradients() const { return fWeightsCandidateGradients; }
0284 Matrix_t & GetWeightsCandidateGradients() { return fWeightsCandidateGradients; }
0285 const Matrix_t & GetWeightsCandidateStateGradients() const { return fWeightsCandidateStateGradients; }
0286 Matrix_t & GetWeightsCandidateStateGradients() { return fWeightsCandidateStateGradients; }
0287 const Matrix_t & GetCandidateBiasGradients() const { return fCandidateBiasGradients; }
0288 Matrix_t & GetCandidateBiasGradients() { return fCandidateBiasGradients; }
0289
0290 Tensor_t &GetWeightsTensor() { return fWeightsTensor; }
0291 const Tensor_t &GetWeightsTensor() const { return fWeightsTensor; }
0292 Tensor_t &GetWeightGradientsTensor() { return fWeightGradientsTensor; }
0293 const Tensor_t &GetWeightGradientsTensor() const { return fWeightGradientsTensor; }
0294
0295 Tensor_t &GetX() { return fX; }
0296 Tensor_t &GetY() { return fY; }
0297 Tensor_t &GetDX() { return fDx; }
0298 Tensor_t &GetDY() { return fDy; }
0299 };
0300
0301
0302
0303
0304
0305
0306
0307 template <typename Architecture_t>
0308 TBasicGRULayer<Architecture_t>::TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
0309 bool rememberState, bool returnSequence, bool resetGateAfter, DNN::EActivationFunction f1,
0310 DNN::EActivationFunction f2, bool ,
0311 DNN::EInitialization fA)
0312 : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize,
0313 6, {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
0314 {inputSize, inputSize, inputSize, stateSize, stateSize, stateSize}, 3,
0315 {stateSize, stateSize, stateSize}, {1, 1, 1}, batchSize,
0316 (returnSequence) ? timeSteps : 1, stateSize, fA),
0317 fStateSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState), fReturnSequence(returnSequence), fResetGateAfter(resetGateAfter),
0318 fF1(f1), fF2(f2), fResetValue(batchSize, stateSize), fUpdateValue(batchSize, stateSize),
0319 fCandidateValue(batchSize, stateSize), fState(batchSize, stateSize), fWeightsResetGate(this->GetWeightsAt(0)),
0320 fWeightsResetGateState(this->GetWeightsAt(3)), fResetGateBias(this->GetBiasesAt(0)),
0321 fWeightsUpdateGate(this->GetWeightsAt(1)), fWeightsUpdateGateState(this->GetWeightsAt(4)),
0322 fUpdateGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
0323 fWeightsCandidateState(this->GetWeightsAt(5)), fCandidateBias(this->GetBiasesAt(2)),
0324 fWeightsResetGradients(this->GetWeightGradientsAt(0)), fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
0325 fResetBiasGradients(this->GetBiasGradientsAt(0)), fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
0326 fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)), fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
0327 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
0328 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
0329 fCandidateBiasGradients(this->GetBiasGradientsAt(2))
0330 {
0331 for (size_t i = 0; i < timeSteps; ++i) {
0332 fDerivativesReset.emplace_back(batchSize, stateSize);
0333 fDerivativesUpdate.emplace_back(batchSize, stateSize);
0334 fDerivativesCandidate.emplace_back(batchSize, stateSize);
0335 reset_gate_value.emplace_back(batchSize, stateSize);
0336 update_gate_value.emplace_back(batchSize, stateSize);
0337 candidate_gate_value.emplace_back(batchSize, stateSize);
0338 }
0339 Architecture_t::InitializeGRUTensors(this);
0340 }
0341
0342
0343 template <typename Architecture_t>
0344 TBasicGRULayer<Architecture_t>::TBasicGRULayer(const TBasicGRULayer &layer)
0345 : VGeneralLayer<Architecture_t>(layer),
0346 fStateSize(layer.fStateSize),
0347 fTimeSteps(layer.fTimeSteps),
0348 fRememberState(layer.fRememberState),
0349 fReturnSequence(layer.fReturnSequence),
0350 fResetGateAfter(layer.fResetGateAfter),
0351 fF1(layer.GetActivationFunctionF1()),
0352 fF2(layer.GetActivationFunctionF2()),
0353 fResetValue(layer.GetBatchSize(), layer.GetStateSize()),
0354 fUpdateValue(layer.GetBatchSize(), layer.GetStateSize()),
0355 fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
0356 fState(layer.GetBatchSize(), layer.GetStateSize()),
0357 fWeightsResetGate(this->GetWeightsAt(0)),
0358 fWeightsResetGateState(this->GetWeightsAt(3)),
0359 fResetGateBias(this->GetBiasesAt(0)),
0360 fWeightsUpdateGate(this->GetWeightsAt(1)),
0361 fWeightsUpdateGateState(this->GetWeightsAt(4)),
0362 fUpdateGateBias(this->GetBiasesAt(1)),
0363 fWeightsCandidate(this->GetWeightsAt(2)),
0364 fWeightsCandidateState(this->GetWeightsAt(5)),
0365 fCandidateBias(this->GetBiasesAt(2)),
0366 fWeightsResetGradients(this->GetWeightGradientsAt(0)),
0367 fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
0368 fResetBiasGradients(this->GetBiasGradientsAt(0)),
0369 fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
0370 fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)),
0371 fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
0372 fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
0373 fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
0374 fCandidateBiasGradients(this->GetBiasGradientsAt(2))
0375 {
0376 for (size_t i = 0; i < fTimeSteps; ++i) {
0377 fDerivativesReset.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0378 Architecture_t::Copy(fDerivativesReset[i], layer.GetResetDerivativesAt(i));
0379
0380 fDerivativesUpdate.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0381 Architecture_t::Copy(fDerivativesUpdate[i], layer.GetUpdateDerivativesAt(i));
0382
0383 fDerivativesCandidate.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0384 Architecture_t::Copy(fDerivativesCandidate[i], layer.GetCandidateDerivativesAt(i));
0385
0386 reset_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0387 Architecture_t::Copy(reset_gate_value[i], layer.GetResetGateTensorAt(i));
0388
0389 update_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0390 Architecture_t::Copy(update_gate_value[i], layer.GetUpdateGateTensorAt(i));
0391
0392 candidate_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0393 Architecture_t::Copy(candidate_gate_value[i], layer.GetCandidateGateTensorAt(i));
0394 }
0395
0396
0397 Architecture_t::Copy(fState, layer.GetState());
0398
0399
0400 Architecture_t::Copy(fResetValue, layer.GetResetGateValue());
0401 Architecture_t::Copy(fCandidateValue, layer.GetCandidateValue());
0402 Architecture_t::Copy(fUpdateValue, layer.GetUpdateGateValue());
0403
0404 Architecture_t::InitializeGRUTensors(this);
0405 }
0406
0407
0408 template <typename Architecture_t>
0409 void TBasicGRULayer<Architecture_t>::Initialize()
0410 {
0411 VGeneralLayer<Architecture_t>::Initialize();
0412
0413 Architecture_t::InitializeGRUDescriptors(fDescriptors, this);
0414 Architecture_t::InitializeGRUWorkspace(fWorkspace, fDescriptors, this);
0415
0416
0417 if (Architecture_t::IsCudnn())
0418 fResetGateAfter = true;
0419 }
0420
0421
0422 template <typename Architecture_t>
0423 auto inline TBasicGRULayer<Architecture_t>::ResetGate(const Matrix_t &input, Matrix_t &dr)
0424 -> void
0425 {
0426
0427
0428
0429 const DNN::EActivationFunction fRst = this->GetActivationFunctionF1();
0430 Matrix_t tmpState(fResetValue.GetNrows(), fResetValue.GetNcols());
0431 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsResetGateState);
0432 Architecture_t::MultiplyTranspose(fResetValue, input, fWeightsResetGate);
0433 Architecture_t::ScaleAdd(fResetValue, tmpState);
0434 Architecture_t::AddRowWise(fResetValue, fResetGateBias);
0435 DNN::evaluateDerivativeMatrix<Architecture_t>(dr, fRst, fResetValue);
0436 DNN::evaluateMatrix<Architecture_t>(fResetValue, fRst);
0437 }
0438
0439
0440 template <typename Architecture_t>
0441 auto inline TBasicGRULayer<Architecture_t>::UpdateGate(const Matrix_t &input, Matrix_t &du)
0442 -> void
0443 {
0444
0445
0446
0447 const DNN::EActivationFunction fUpd = this->GetActivationFunctionF1();
0448 Matrix_t tmpState(fUpdateValue.GetNrows(), fUpdateValue.GetNcols());
0449 Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsUpdateGateState);
0450 Architecture_t::MultiplyTranspose(fUpdateValue, input, fWeightsUpdateGate);
0451 Architecture_t::ScaleAdd(fUpdateValue, tmpState);
0452 Architecture_t::AddRowWise(fUpdateValue, fUpdateGateBias);
0453 DNN::evaluateDerivativeMatrix<Architecture_t>(du, fUpd, fUpdateValue);
0454 DNN::evaluateMatrix<Architecture_t>(fUpdateValue, fUpd);
0455 }
0456
0457
0458 template <typename Architecture_t>
0459 auto inline TBasicGRULayer<Architecture_t>::CandidateValue(const Matrix_t &input, Matrix_t &dc)
0460 -> void
0461 {
0462
0463
0464
0465
0466
0467
0468
0469
0470
0471
0472
0473
0474 const DNN::EActivationFunction fCan = this->GetActivationFunctionF2();
0475 Matrix_t tmp(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
0476 if (!fResetGateAfter) {
0477 Matrix_t tmpState(fResetValue);
0478 Architecture_t::Hadamard(tmpState, fState);
0479 Architecture_t::MultiplyTranspose(tmp, tmpState, fWeightsCandidateState);
0480 } else {
0481
0482 Architecture_t::MultiplyTranspose(tmp, fState, fWeightsCandidateState);
0483 Architecture_t::Hadamard(tmp, fResetValue);
0484 }
0485 Architecture_t::MultiplyTranspose(fCandidateValue, input, fWeightsCandidate);
0486 Architecture_t::ScaleAdd(fCandidateValue, tmp);
0487 Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
0488 DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
0489 DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
0490 }
0491
0492
0493 template <typename Architecture_t>
0494 auto inline TBasicGRULayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining )
0495 -> void
0496 {
0497
0498 if (Architecture_t::IsCudnn()) {
0499
0500
0501 assert(input.GetStrides()[1] == this->GetInputSize());
0502
0503 Tensor_t &x = this->fX;
0504 Tensor_t &y = this->fY;
0505 Architecture_t::Rearrange(x, input);
0506
0507
0508 const auto &weights = this->GetWeightsTensor();
0509
0510 auto &hx = this->fState;
0511 auto &cx = this->fCell;
0512
0513 auto &hy = this->fState;
0514 auto &cy = this->fCell;
0515
0516 auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
0517 auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
0518
0519 Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);
0520
0521 if (fReturnSequence) {
0522 Architecture_t::Rearrange(this->GetOutput(), y);
0523 } else {
0524
0525 Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
0526 Architecture_t::Copy(this->GetOutput(), tmp);
0527 }
0528
0529 return;
0530 }
0531
0532
0533
0534
0535
0536
0537 Tensor_t arrInput ( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
0538
0539
0540
0541 Architecture_t::Rearrange(arrInput, input);
0542
0543 Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize );
0544
0545
0546
0547
0548 if (!this->fRememberState) {
0549 InitState(DNN::EInitialization::kZero);
0550 }
0551
0552
0553
0554 for (size_t t = 0; t < fTimeSteps; ++t) {
0555
0556 ResetGate(arrInput[t], fDerivativesReset[t]);
0557 Architecture_t::Copy(this->GetResetGateTensorAt(t), fResetValue);
0558 UpdateGate(arrInput[t], fDerivativesUpdate[t]);
0559 Architecture_t::Copy(this->GetUpdateGateTensorAt(t), fUpdateValue);
0560
0561 CandidateValue(arrInput[t], fDerivativesCandidate[t]);
0562 Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
0563
0564
0565 CellForward(fUpdateValue, fCandidateValue);
0566
0567
0568
0569 Matrix_t arrOutputMt = arrOutput[t];
0570 Architecture_t::Copy(arrOutputMt, fState);
0571 }
0572
0573 if (fReturnSequence)
0574 Architecture_t::Rearrange(this->GetOutput(), arrOutput);
0575 else {
0576
0577 Tensor_t tmp = arrOutput.At(fTimeSteps - 1);
0578
0579
0580 tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
0581 assert(tmp.GetSize() == this->GetOutput().GetSize());
0582 assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);
0583 Architecture_t::Rearrange(this->GetOutput(), tmp);
0584
0585 fY = arrOutput;
0586 }
0587 }
0588
0589
0590 template <typename Architecture_t>
0591 auto inline TBasicGRULayer<Architecture_t>::CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues)
0592 -> void
0593 {
0594 Architecture_t::Hadamard(fState, updateGateValues);
0595
0596
0597 Matrix_t tmp(updateGateValues);
0598 for (size_t j = 0; j < (size_t) tmp.GetNcols(); j++) {
0599 for (size_t i = 0; i < (size_t) tmp.GetNrows(); i++) {
0600 tmp(i,j) = 1 - tmp(i,j);
0601 }
0602 }
0603
0604
0605 Architecture_t::Hadamard(candidateValues, tmp);
0606 Architecture_t::ScaleAdd(fState, candidateValues);
0607 }
0608
0609
0610 template <typename Architecture_t>
0611 auto inline TBasicGRULayer<Architecture_t>::Backward(Tensor_t &gradients_backward,
0612 const Tensor_t &activations_backward)
0613 -> void
0614 {
0615
0616 if (Architecture_t::IsCudnn()) {
0617
0618 Tensor_t &x = this->fX;
0619 Tensor_t &y = this->fY;
0620 Tensor_t &dx = this->fDx;
0621 Tensor_t &dy = this->fDy;
0622
0623
0624 assert(activations_backward.GetStrides()[1] == this->GetInputSize());
0625
0626
0627 Architecture_t::Rearrange(x, activations_backward);
0628
0629 if (!fReturnSequence) {
0630
0631
0632 Architecture_t::InitializeZero(dy);
0633
0634
0635 Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
0636
0637
0638 Architecture_t::Copy(tmp2, this->GetActivationGradients());
0639 } else {
0640 Architecture_t::Rearrange(y, this->GetOutput());
0641 Architecture_t::Rearrange(dy, this->GetActivationGradients());
0642 }
0643
0644
0645
0646
0647 const auto &weights = this->GetWeightsTensor();
0648 auto &weightGradients = this->GetWeightGradientsTensor();
0649
0650
0651
0652 Architecture_t::InitializeZero(weightGradients);
0653
0654
0655 auto &hx = this->GetState();
0656 auto &cx = this->GetCell();
0657
0658 auto &dhy = hx;
0659 auto &dcy = cx;
0660 auto &dhx = hx;
0661 auto &dcx = cx;
0662
0663 auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
0664 auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
0665
0666 Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
0667
0668
0669
0670 if (gradients_backward.GetSize() != 0)
0671 Architecture_t::Rearrange(gradients_backward, dx);
0672
0673 return;
0674 }
0675
0676
0677
0678
0679 Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize);
0680 DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero);
0681
0682
0683 bool dummy = false;
0684 if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
0685 dummy = true;
0686 }
0687
0688 Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
0689
0690
0691
0692
0693 Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
0694
0695 Architecture_t::Rearrange(arr_activations_backward, activations_backward);
0696
0697
0698
0699 Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);
0700
0701 Matrix_t initState(this->GetBatchSize(), fStateSize);
0702 DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero);
0703
0704
0705 Tensor_t arr_actgradients ( fTimeSteps, this->GetBatchSize(), fStateSize);
0706
0707 if (fReturnSequence) {
0708 Architecture_t::Rearrange(arr_output, this->GetOutput());
0709 Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
0710 } else {
0711
0712 arr_output = fY;
0713 Architecture_t::InitializeZero(arr_actgradients);
0714
0715 Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
0716 assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
0717 assert(tmp_grad.GetShape()[0] ==
0718 this->GetActivationGradients().GetShape()[2]);
0719
0720 Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
0721 }
0722
0723
0724
0725
0726
0727 fWeightsResetGradients.Zero();
0728 fWeightsResetStateGradients.Zero();
0729 fResetBiasGradients.Zero();
0730
0731
0732 fWeightsUpdateGradients.Zero();
0733 fWeightsUpdateStateGradients.Zero();
0734 fUpdateBiasGradients.Zero();
0735
0736
0737 fWeightsCandidateGradients.Zero();
0738 fWeightsCandidateStateGradients.Zero();
0739 fCandidateBiasGradients.Zero();
0740
0741
0742 for (size_t t = fTimeSteps; t > 0; t--) {
0743
0744 Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
0745 if (t > 1) {
0746 const Matrix_t &prevStateActivations = arr_output[t-2];
0747 Matrix_t dx = arr_gradients_backward[t-1];
0748
0749 CellBackward(state_gradients_backward, prevStateActivations,
0750 this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
0751 this->GetCandidateGateTensorAt(t-1),
0752 arr_activations_backward[t-1], dx ,
0753 fDerivativesReset[t-1], fDerivativesUpdate[t-1],
0754 fDerivativesCandidate[t-1]);
0755 } else {
0756 const Matrix_t &prevStateActivations = initState;
0757 Matrix_t dx = arr_gradients_backward[t-1];
0758 CellBackward(state_gradients_backward, prevStateActivations,
0759 this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
0760 this->GetCandidateGateTensorAt(t-1),
0761 arr_activations_backward[t-1], dx ,
0762 fDerivativesReset[t-1], fDerivativesUpdate[t-1],
0763 fDerivativesCandidate[t-1]);
0764 }
0765 }
0766
0767 if (!dummy) {
0768 Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
0769 }
0770
0771 }
0772
0773
0774
0775 template <typename Architecture_t>
0776 auto inline TBasicGRULayer<Architecture_t>::CellBackward(Matrix_t & state_gradients_backward,
0777 const Matrix_t & precStateActivations,
0778 const Matrix_t & reset_gate, const Matrix_t & update_gate,
0779 const Matrix_t & candidate_gate,
0780 const Matrix_t & input, Matrix_t & input_gradient,
0781 Matrix_t &dr, Matrix_t &du, Matrix_t &dc)
0782 -> Matrix_t &
0783 {
0784
0785
0786 return Architecture_t::GRULayerBackward(state_gradients_backward,
0787 fWeightsResetGradients, fWeightsUpdateGradients, fWeightsCandidateGradients,
0788 fWeightsResetStateGradients, fWeightsUpdateStateGradients,
0789 fWeightsCandidateStateGradients, fResetBiasGradients, fUpdateBiasGradients,
0790 fCandidateBiasGradients, dr, du, dc,
0791 precStateActivations,
0792 reset_gate, update_gate, candidate_gate,
0793 fWeightsResetGate, fWeightsUpdateGate, fWeightsCandidate,
0794 fWeightsResetGateState, fWeightsUpdateGateState, fWeightsCandidateState,
0795 input, input_gradient, fResetGateAfter);
0796 }
0797
0798
0799
0800 template <typename Architecture_t>
0801 auto TBasicGRULayer<Architecture_t>::InitState(DNN::EInitialization )
0802 -> void
0803 {
0804 DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
0805 }
0806
0807
0808 template<typename Architecture_t>
0809 auto TBasicGRULayer<Architecture_t>::Print() const
0810 -> void
0811 {
0812 std::cout << " GRU Layer: \t ";
0813 std::cout << " (NInput = " << this->GetInputSize();
0814 std::cout << ", NState = " << this->GetStateSize();
0815 std::cout << ", NTime = " << this->GetTimeSteps() << " )";
0816 std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput()[0].GetNrows() << " , " << this->GetOutput()[0].GetNcols() << " )\n";
0817 }
0818
0819
0820 template <typename Architecture_t>
0821 auto inline TBasicGRULayer<Architecture_t>::AddWeightsXMLTo(void *parent)
0822 -> void
0823 {
0824 auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "GRULayer");
0825
0826
0827 gTools().xmlengine().NewAttr(layerxml, nullptr, "StateSize", gTools().StringFromInt(this->GetStateSize()));
0828 gTools().xmlengine().NewAttr(layerxml, nullptr, "InputSize", gTools().StringFromInt(this->GetInputSize()));
0829 gTools().xmlengine().NewAttr(layerxml, nullptr, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
0830 gTools().xmlengine().NewAttr(layerxml, nullptr, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
0831 gTools().xmlengine().NewAttr(layerxml, nullptr, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));
0832 gTools().xmlengine().NewAttr(layerxml, nullptr, "ResetGateAfter", gTools().StringFromInt(this->fResetGateAfter));
0833
0834
0835 this->WriteMatrixToXML(layerxml, "ResetWeights", this->GetWeightsAt(0));
0836 this->WriteMatrixToXML(layerxml, "ResetStateWeights", this->GetWeightsAt(1));
0837 this->WriteMatrixToXML(layerxml, "ResetBiases", this->GetBiasesAt(0));
0838 this->WriteMatrixToXML(layerxml, "UpdateWeights", this->GetWeightsAt(2));
0839 this->WriteMatrixToXML(layerxml, "UpdateStateWeights", this->GetWeightsAt(3));
0840 this->WriteMatrixToXML(layerxml, "UpdateBiases", this->GetBiasesAt(1));
0841 this->WriteMatrixToXML(layerxml, "CandidateWeights", this->GetWeightsAt(4));
0842 this->WriteMatrixToXML(layerxml, "CandidateStateWeights", this->GetWeightsAt(5));
0843 this->WriteMatrixToXML(layerxml, "CandidateBiases", this->GetBiasesAt(2));
0844 }
0845
0846
0847 template <typename Architecture_t>
0848 auto inline TBasicGRULayer<Architecture_t>::ReadWeightsFromXML(void *parent)
0849 -> void
0850 {
0851
0852 this->ReadMatrixXML(parent, "ResetWeights", this->GetWeightsAt(0));
0853 this->ReadMatrixXML(parent, "ResetStateWeights", this->GetWeightsAt(1));
0854 this->ReadMatrixXML(parent, "ResetBiases", this->GetBiasesAt(0));
0855 this->ReadMatrixXML(parent, "UpdateWeights", this->GetWeightsAt(2));
0856 this->ReadMatrixXML(parent, "UpdateStateWeights", this->GetWeightsAt(3));
0857 this->ReadMatrixXML(parent, "UpdateBiases", this->GetBiasesAt(1));
0858 this->ReadMatrixXML(parent, "CandidateWeights", this->GetWeightsAt(4));
0859 this->ReadMatrixXML(parent, "CandidateStateWeights", this->GetWeightsAt(5));
0860 this->ReadMatrixXML(parent, "CandidateBiases", this->GetBiasesAt(2));
0861 }
0862
0863 }
0864 }
0865 }
0866
0867 #endif