// @(#)root/tmva/tmva/dnn/gru:$Id$
// Author: Surya S Dwivedi 03/07/19

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class : BasicGRULayer                                                         *
 *                                                                                *
 * Description:                                                                   *
 *       NeuralNetwork                                                            *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *       Surya S Dwivedi  <surya2191997@gmail.com> - IIT Kharagpur, India         *
 *                                                                                *
 * Copyright (c) 2005-2019:                                                       *
 * All rights reserved.                                                           *
 *       CERN, Switzerland                                                        *
 *                                                                                *
 * For the licensing terms see $ROOTSYS/LICENSE.                                  *
 * For the list of contributors see $ROOTSYS/README/CREDITS.                      *
 **********************************************************************************/

//#pragma once

//////////////////////////////////////////////////////////////////////
// This class implements the GRU layer. GRU is a variant of vanilla
// RNN which is capable of learning long range dependencies.
//////////////////////////////////////////////////////////////////////

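//////////////////////////////////////////////////////////////////////
// Per time step t the layer computes (element-wise product written
// as *; sigmoid and tanh are the default activations fF1/fF2, as
// implemented in ResetGate(), UpdateGate(), CandidateValue() and
// CellForward() below):
//
//   reset:     r_t = sigmoid(W_r . x_t + U_r . h_{t-1} + b_r)
//   update:    u_t = sigmoid(W_u . x_t + U_u . h_{t-1} + b_u)
//   candidate: c_t = tanh(W_c . x_t + U_c . (r_t * h_{t-1}) + b_c)
//   state:     h_t = u_t * h_{t-1} + (1 - u_t) * c_t
//
// (the cuDNN reset-after variant applies r_t after the U_c product.)
//////////////////////////////////////////////////////////////////////
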
#ifndef TMVA_DNN_GRU_LAYER
#define TMVA_DNN_GRU_LAYER

#include <cmath>
#include <iostream>
#include <vector>

#include "TMatrix.h"
#include "TMVA/DNN/Functions.h"

namespace TMVA
{
namespace DNN
{
namespace RNN
{

//______________________________________________________________________________
//
// Basic GRU Layer
//______________________________________________________________________________

/** \class BasicGRULayer
      Generic implementation of a gated recurrent unit (GRU) layer
*/
template<typename Architecture_t>
class TBasicGRULayer : public VGeneralLayer<Architecture_t>
{

public:

   using Matrix_t = typename Architecture_t::Matrix_t;
   using Scalar_t = typename Architecture_t::Scalar_t;
   using Tensor_t = typename Architecture_t::Tensor_t;

   using LayerDescriptor_t = typename Architecture_t::RecurrentDescriptor_t;
   using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
   using TensorDescriptor_t = typename Architecture_t::TensorDescriptor_t;
   using HelperDescriptor_t = typename Architecture_t::DropoutDescriptor_t;

   using RNNWorkspace_t = typename Architecture_t::RNNWorkspace_t;
   using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;

private:

   size_t fStateSize;                           ///< Hidden state size for GRU
   size_t fTimeSteps;                           ///< Timesteps for GRU

   bool fRememberState;                         ///< Remember state in next pass
   bool fReturnSequence = false;                ///< Return the full sequence in the output, or just the last element
   bool fResetGateAfter = false;                ///< GRU variant that applies the reset gate multiplication after the matrix multiplication (used by cuDNN)

   DNN::EActivationFunction fF1;                ///< Activation function: sigmoid
   DNN::EActivationFunction fF2;                ///< Activation function: tanh

   Matrix_t fResetValue;                        ///< Computed reset gate values
   Matrix_t fUpdateValue;                       ///< Computed update gate values
   Matrix_t fCandidateValue;                    ///< Computed candidate values
   Matrix_t fState;                             ///< Hidden state of GRU


   Matrix_t &fWeightsResetGate;                 ///< Reset Gate weights for input, fWeights[0]
   Matrix_t &fWeightsResetGateState;            ///< Reset Gate weights for prev state, fWeights[3]
   Matrix_t &fResetGateBias;                    ///< Reset Gate bias

   Matrix_t &fWeightsUpdateGate;                ///< Update Gate weights for input, fWeights[1]
   Matrix_t &fWeightsUpdateGateState;           ///< Update Gate weights for prev state, fWeights[4]
   Matrix_t &fUpdateGateBias;                   ///< Update Gate bias

   Matrix_t &fWeightsCandidate;                 ///< Candidate Gate weights for input, fWeights[2]
   Matrix_t &fWeightsCandidateState;            ///< Candidate Gate weights for prev state, fWeights[5]
   Matrix_t &fCandidateBias;                    ///< Candidate Gate bias


   std::vector<Matrix_t> reset_gate_value;      ///< Reset gate value for every time step
   std::vector<Matrix_t> update_gate_value;     ///< Update gate value for every time step
   std::vector<Matrix_t> candidate_gate_value;  ///< Candidate gate value for every time step

   std::vector<Matrix_t> fDerivativesReset;     ///< First derivatives of the reset gate activations
   std::vector<Matrix_t> fDerivativesUpdate;    ///< First derivatives of the update gate activations
   std::vector<Matrix_t> fDerivativesCandidate; ///< First derivatives of the candidate gate activations

   Matrix_t &fWeightsResetGradients;            ///< Gradients w.r.t the reset gate - input weights
   Matrix_t &fWeightsResetStateGradients;       ///< Gradients w.r.t the reset gate - hidden state weights
   Matrix_t &fResetBiasGradients;               ///< Gradients w.r.t the reset gate - bias weights
   Matrix_t &fWeightsUpdateGradients;           ///< Gradients w.r.t the update gate - input weights
   Matrix_t &fWeightsUpdateStateGradients;      ///< Gradients w.r.t the update gate - hidden state weights
   Matrix_t &fUpdateBiasGradients;              ///< Gradients w.r.t the update gate - bias weights
   Matrix_t &fWeightsCandidateGradients;        ///< Gradients w.r.t the candidate gate - input weights
   Matrix_t &fWeightsCandidateStateGradients;   ///< Gradients w.r.t the candidate gate - hidden state weights
   Matrix_t &fCandidateBiasGradients;           ///< Gradients w.r.t the candidate gate - bias weights

   Matrix_t fCell;                              ///< Empty matrix; a GRU has no cell state (kept for the cuDNN RNN API)

   // Tensor representing all weights (used by cuDNN)
   Tensor_t fWeightsTensor;         ///< Tensor for all weights
   Tensor_t fWeightGradientsTensor; ///< Tensor for all weight gradients

   // tensors used internally for the forward and backward pass
   Tensor_t fX;  ///< cached input tensor as T x B x I
   Tensor_t fY;  ///< cached output tensor as T x B x S
   Tensor_t fDx; ///< cached gradient on the input (output of backward) as T x B x I
   Tensor_t fDy; ///< cached activation gradient (input of backward) as T x B x S

   TDescriptors *fDescriptors = nullptr; ///< Keeps all the RNN descriptors
   TWorkspace *fWorkspace = nullptr;     ///< Workspace needed for GPU computation (cuDNN)

public:

   /*! Constructor */
   TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize,
                   size_t timeSteps, bool rememberState = false, bool returnSequence = false,
                   bool resetGateAfter = false,
                   DNN::EActivationFunction f1 = DNN::EActivationFunction::kSigmoid,
                   DNN::EActivationFunction f2 = DNN::EActivationFunction::kTanh,
                   bool training = true, DNN::EInitialization fA = DNN::EInitialization::kZero);

   /*! Copy Constructor */
   TBasicGRULayer(const TBasicGRULayer &);

   /*! Initialize the weights according to the given initialization
    **  method. */
   virtual void Initialize();

   /*! Initialize the hidden state; a GRU has no cell state. */
   void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);

   /*! Computes the next hidden state
    *  with the given input matrix. */
   void Forward(Tensor_t &input, bool isTraining = true);

   /*! Forward for a single cell (time unit) */
   void CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues);

   /*! Backpropagates the error. Must only be called directly after the
    *  corresponding call to Forward(...). */
   void Backward(Tensor_t &gradients_backward,
                 const Tensor_t &activations_backward);

   /* Updates weights and biases, given the learning rate */
   void Update(const Scalar_t learningRate);

   /*! Backward for a single time unit,
    *  corresponding to the call to Forward(...). */
   Matrix_t & CellBackward(Matrix_t & state_gradients_backward,
                           const Matrix_t & precStateActivations,
                           const Matrix_t & reset_gate, const Matrix_t & update_gate,
                           const Matrix_t & candidate_gate,
                           const Matrix_t & input, Matrix_t & input_gradient,
                           Matrix_t &dr, Matrix_t &du, Matrix_t &dc);

   /*! Computes the reset gate values (NN with Sigmoid) */
   void ResetGate(const Matrix_t &input, Matrix_t &di);

   /*! Computes the update gate values (NN with Sigmoid) */
   void UpdateGate(const Matrix_t &input, Matrix_t &df);

   /*! Computes the new candidate values (NN with Tanh) */
   void CandidateValue(const Matrix_t &input, Matrix_t &dc);


   /*! Prints the info about the layer */
   void Print() const;

   /*! Writes the information and the weights about the layer in an XML node. */
   void AddWeightsXMLTo(void *parent);

   /*! Read the information and the weights about the layer from XML node. */
   void ReadWeightsFromXML(void *parent);

   /*! Getters */
   size_t GetInputSize()               const { return this->GetInputWidth(); }
   size_t GetTimeSteps()               const { return fTimeSteps; }
   size_t GetStateSize()               const { return fStateSize; }

   inline bool DoesRememberState()  const { return fRememberState; }
   inline bool DoesReturnSequence() const { return fReturnSequence; }

   inline DNN::EActivationFunction     GetActivationFunctionF1()        const { return fF1; }
   inline DNN::EActivationFunction     GetActivationFunctionF2()        const { return fF2; }

   const Matrix_t                    & GetResetGateValue()                const { return fResetValue; }
   Matrix_t                          & GetResetGateValue()                      { return fResetValue; }
   const Matrix_t                    & GetCandidateValue()                const { return fCandidateValue; }
   Matrix_t                          & GetCandidateValue()                      { return fCandidateValue; }
   const Matrix_t                    & GetUpdateGateValue()               const { return fUpdateValue; }
   Matrix_t                          & GetUpdateGateValue()                     { return fUpdateValue; }

   const Matrix_t                    & GetState()                   const { return fState; }
   Matrix_t                          & GetState()                         { return fState; }
   const Matrix_t                    & GetCell()                    const { return fCell; }
   Matrix_t                          & GetCell()                          { return fCell; }

   const Matrix_t                    & GetWeightsResetGate()              const { return fWeightsResetGate; }
   Matrix_t                          & GetWeightsResetGate()                    { return fWeightsResetGate; }
   const Matrix_t                    & GetWeightsCandidate()              const { return fWeightsCandidate; }
   Matrix_t                          & GetWeightsCandidate()                    { return fWeightsCandidate; }
   const Matrix_t                    & GetWeightsUpdateGate()             const { return fWeightsUpdateGate; }
   Matrix_t                          & GetWeightsUpdateGate()                   { return fWeightsUpdateGate; }

   const Matrix_t                    & GetWeightsResetGateState()         const { return fWeightsResetGateState; }
   Matrix_t                          & GetWeightsResetGateState()               { return fWeightsResetGateState; }
   const Matrix_t                    & GetWeightsUpdateGateState()        const { return fWeightsUpdateGateState; }
   Matrix_t                          & GetWeightsUpdateGateState()              { return fWeightsUpdateGateState; }
   const Matrix_t                    & GetWeightsCandidateState()         const { return fWeightsCandidateState; }
   Matrix_t                          & GetWeightsCandidateState()               { return fWeightsCandidateState; }

   const std::vector<Matrix_t>       & GetDerivativesReset()              const { return fDerivativesReset; }
   std::vector<Matrix_t>             & GetDerivativesReset()                    { return fDerivativesReset; }
   const Matrix_t                    & GetResetDerivativesAt(size_t i)    const { return fDerivativesReset[i]; }
   Matrix_t                          & GetResetDerivativesAt(size_t i)          { return fDerivativesReset[i]; }
   const std::vector<Matrix_t>       & GetDerivativesUpdate()              const { return fDerivativesUpdate; }
   std::vector<Matrix_t>             & GetDerivativesUpdate()                    { return fDerivativesUpdate; }
   const Matrix_t                    & GetUpdateDerivativesAt(size_t i)    const { return fDerivativesUpdate[i]; }
   Matrix_t                          & GetUpdateDerivativesAt(size_t i)          { return fDerivativesUpdate[i]; }
   const std::vector<Matrix_t>       & GetDerivativesCandidate()           const { return fDerivativesCandidate; }
   std::vector<Matrix_t>             & GetDerivativesCandidate()                 { return fDerivativesCandidate; }
   const Matrix_t                    & GetCandidateDerivativesAt(size_t i) const { return fDerivativesCandidate[i]; }
   Matrix_t                          & GetCandidateDerivativesAt(size_t i)       { return fDerivativesCandidate[i]; }

   const std::vector<Matrix_t>       & GetResetGateTensor()              const { return reset_gate_value; }
   std::vector<Matrix_t>             & GetResetGateTensor()                    { return reset_gate_value; }
   const Matrix_t                    & GetResetGateTensorAt(size_t i)    const { return reset_gate_value[i]; }
   Matrix_t                          & GetResetGateTensorAt(size_t i)          { return reset_gate_value[i]; }
   const std::vector<Matrix_t>       & GetUpdateGateTensor()              const { return update_gate_value; }
   std::vector<Matrix_t>             & GetUpdateGateTensor()                    { return update_gate_value; }
   const Matrix_t                    & GetUpdateGateTensorAt(size_t i)    const { return update_gate_value[i]; }
   Matrix_t                          & GetUpdateGateTensorAt(size_t i)          { return update_gate_value[i]; }
   const std::vector<Matrix_t>       & GetCandidateGateTensor()           const { return candidate_gate_value; }
   std::vector<Matrix_t>             & GetCandidateGateTensor()                 { return candidate_gate_value; }
   const Matrix_t                    & GetCandidateGateTensorAt(size_t i) const { return candidate_gate_value[i]; }
   Matrix_t                          & GetCandidateGateTensorAt(size_t i)       { return candidate_gate_value[i]; }



   const Matrix_t                   & GetResetGateBias()         const { return fResetGateBias; }
   Matrix_t                         & GetResetGateBias()               { return fResetGateBias; }
   const Matrix_t                   & GetUpdateGateBias()        const { return fUpdateGateBias; }
   Matrix_t                         & GetUpdateGateBias()              { return fUpdateGateBias; }
   const Matrix_t                   & GetCandidateBias()         const { return fCandidateBias; }
   Matrix_t                         & GetCandidateBias()               { return fCandidateBias; }

   const Matrix_t                   & GetWeightsResetGradients()        const { return fWeightsResetGradients; }
   Matrix_t                         & GetWeightsResetGradients()              { return fWeightsResetGradients; }
   const Matrix_t                   & GetWeightsResetStateGradients()   const { return fWeightsResetStateGradients; }
   Matrix_t                         & GetWeightsResetStateGradients()         { return fWeightsResetStateGradients; }
   const Matrix_t                   & GetResetBiasGradients()           const { return fResetBiasGradients; }
   Matrix_t                         & GetResetBiasGradients()                 { return fResetBiasGradients; }
   const Matrix_t                   & GetWeightsUpdateGradients()       const { return fWeightsUpdateGradients; }
   Matrix_t                         & GetWeightsUpdateGradients()             { return fWeightsUpdateGradients; }
   const Matrix_t                   & GetWeightsUpdateStateGradients()  const { return fWeightsUpdateStateGradients; }
   Matrix_t                         & GetWeightsUpdateStateGradients()        { return fWeightsUpdateStateGradients; }
   const Matrix_t                   & GetUpdateBiasGradients()          const { return fUpdateBiasGradients; }
   Matrix_t                         & GetUpdateBiasGradients()                { return fUpdateBiasGradients; }
   const Matrix_t                   & GetWeightsCandidateGradients()      const { return fWeightsCandidateGradients; }
   Matrix_t                         & GetWeightsCandidateGradients()            { return fWeightsCandidateGradients; }
   const Matrix_t                   & GetWeightsCandidateStateGradients() const { return fWeightsCandidateStateGradients; }
   Matrix_t                         & GetWeightsCandidateStateGradients()       { return fWeightsCandidateStateGradients; }
   const Matrix_t                   & GetCandidateBiasGradients()         const { return fCandidateBiasGradients; }
   Matrix_t                         & GetCandidateBiasGradients()               { return fCandidateBiasGradients; }

   Tensor_t &GetWeightsTensor() { return fWeightsTensor; }
   const Tensor_t &GetWeightsTensor() const { return fWeightsTensor; }
   Tensor_t &GetWeightGradientsTensor() { return fWeightGradientsTensor; }
   const Tensor_t &GetWeightGradientsTensor() const { return fWeightGradientsTensor; }

   Tensor_t &GetX() { return fX; }
   Tensor_t &GetY() { return fY; }
   Tensor_t &GetDX() { return fDx; }
   Tensor_t &GetDY() { return fDy; }
};
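
//______________________________________________________________________________
//
// Minimal usage sketch (illustrative, not part of the class). It assumes
// the TMVA CPU backend TCpu<Double_t> and a three-index Tensor_t
// constructor, as used inside Forward() below; treat both as assumptions
// rather than a verified example.
//
//    using Arch_t = TMVA::DNN::TCpu<Double_t>;
//    using GRU_t  = TMVA::DNN::RNN::TBasicGRULayer<Arch_t>;
//
//    const size_t batchSize = 4, stateSize = 8, inputSize = 16, timeSteps = 5;
//    GRU_t gru(batchSize, stateSize, inputSize, timeSteps);
//    gru.Initialize();
//
//    GRU_t::Tensor_t input(batchSize, timeSteps, inputSize); // B x T x D
//    gru.Forward(input, /*isTraining=*/false);
//    const auto &output = gru.GetOutput(); // B x 1 x H (returnSequence = false)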


//______________________________________________________________________________
//
// Basic GRU-Layer Implementation
//______________________________________________________________________________

template <typename Architecture_t>
TBasicGRULayer<Architecture_t>::TBasicGRULayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
                                               bool rememberState, bool returnSequence, bool resetGateAfter, DNN::EActivationFunction f1,
                                               DNN::EActivationFunction f2, bool /* training */,
                                               DNN::EInitialization fA)
   : VGeneralLayer<Architecture_t>(batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize,
                                   6, {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
                                   {inputSize, inputSize, inputSize, stateSize, stateSize, stateSize}, 3,
                                   {stateSize, stateSize, stateSize}, {1, 1, 1}, batchSize,
                                   (returnSequence) ? timeSteps : 1, stateSize, fA),
     fStateSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState), fReturnSequence(returnSequence), fResetGateAfter(resetGateAfter),
     fF1(f1), fF2(f2), fResetValue(batchSize, stateSize), fUpdateValue(batchSize, stateSize),
     fCandidateValue(batchSize, stateSize), fState(batchSize, stateSize), fWeightsResetGate(this->GetWeightsAt(0)),
     fWeightsResetGateState(this->GetWeightsAt(3)), fResetGateBias(this->GetBiasesAt(0)),
     fWeightsUpdateGate(this->GetWeightsAt(1)), fWeightsUpdateGateState(this->GetWeightsAt(4)),
     fUpdateGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
     fWeightsCandidateState(this->GetWeightsAt(5)), fCandidateBias(this->GetBiasesAt(2)),
     fWeightsResetGradients(this->GetWeightGradientsAt(0)), fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
     fResetBiasGradients(this->GetBiasGradientsAt(0)), fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
     fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)), fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
     fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
     fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
     fCandidateBiasGradients(this->GetBiasGradientsAt(2))
{
   for (size_t i = 0; i < timeSteps; ++i) {
      fDerivativesReset.emplace_back(batchSize, stateSize);
      fDerivativesUpdate.emplace_back(batchSize, stateSize);
      fDerivativesCandidate.emplace_back(batchSize, stateSize);
      reset_gate_value.emplace_back(batchSize, stateSize);
      update_gate_value.emplace_back(batchSize, stateSize);
      candidate_gate_value.emplace_back(batchSize, stateSize);
   }
   Architecture_t::InitializeGRUTensors(this);
}
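
// Weight and bias ordering established by the reference bindings above:
// fWeights = { W_reset, W_update, W_candidate, U_reset, U_update, U_candidate }
// (input weights first, then the corresponding state weights), and
// fBiases = { b_reset, b_update, b_candidate }.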

//______________________________________________________________________________
template <typename Architecture_t>
TBasicGRULayer<Architecture_t>::TBasicGRULayer(const TBasicGRULayer &layer)
   : VGeneralLayer<Architecture_t>(layer),
      fStateSize(layer.fStateSize),
      fTimeSteps(layer.fTimeSteps),
      fRememberState(layer.fRememberState),
      fReturnSequence(layer.fReturnSequence),
      fResetGateAfter(layer.fResetGateAfter),
      fF1(layer.GetActivationFunctionF1()),
      fF2(layer.GetActivationFunctionF2()),
      fResetValue(layer.GetBatchSize(), layer.GetStateSize()),
      fUpdateValue(layer.GetBatchSize(), layer.GetStateSize()),
      fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
      fState(layer.GetBatchSize(), layer.GetStateSize()),
      fWeightsResetGate(this->GetWeightsAt(0)),
      fWeightsResetGateState(this->GetWeightsAt(3)),
      fResetGateBias(this->GetBiasesAt(0)),
      fWeightsUpdateGate(this->GetWeightsAt(1)),
      fWeightsUpdateGateState(this->GetWeightsAt(4)),
      fUpdateGateBias(this->GetBiasesAt(1)),
      fWeightsCandidate(this->GetWeightsAt(2)),
      fWeightsCandidateState(this->GetWeightsAt(5)),
      fCandidateBias(this->GetBiasesAt(2)),
      fWeightsResetGradients(this->GetWeightGradientsAt(0)),
      fWeightsResetStateGradients(this->GetWeightGradientsAt(3)),
      fResetBiasGradients(this->GetBiasGradientsAt(0)),
      fWeightsUpdateGradients(this->GetWeightGradientsAt(1)),
      fWeightsUpdateStateGradients(this->GetWeightGradientsAt(4)),
      fUpdateBiasGradients(this->GetBiasGradientsAt(1)),
      fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
      fWeightsCandidateStateGradients(this->GetWeightGradientsAt(5)),
      fCandidateBiasGradients(this->GetBiasGradientsAt(2))
{
   for (size_t i = 0; i < fTimeSteps; ++i) {
      fDerivativesReset.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(fDerivativesReset[i], layer.GetResetDerivativesAt(i));

      fDerivativesUpdate.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(fDerivativesUpdate[i], layer.GetUpdateDerivativesAt(i));

      fDerivativesCandidate.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(fDerivativesCandidate[i], layer.GetCandidateDerivativesAt(i));

      reset_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(reset_gate_value[i], layer.GetResetGateTensorAt(i));

      update_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(update_gate_value[i], layer.GetUpdateGateTensorAt(i));

      candidate_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
      Architecture_t::Copy(candidate_gate_value[i], layer.GetCandidateGateTensorAt(i));
   }

   // Gradient matrices are not copied
   Architecture_t::Copy(fState, layer.GetState());

   // Copy each gate's values.
   Architecture_t::Copy(fResetValue, layer.GetResetGateValue());
   Architecture_t::Copy(fCandidateValue, layer.GetCandidateValue());
   Architecture_t::Copy(fUpdateValue, layer.GetUpdateGateValue());

   Architecture_t::InitializeGRUTensors(this);
}

//______________________________________________________________________________
template <typename Architecture_t>
void TBasicGRULayer<Architecture_t>::Initialize()
{
   VGeneralLayer<Architecture_t>::Initialize();

   Architecture_t::InitializeGRUDescriptors(fDescriptors, this);
   Architecture_t::InitializeGRUWorkspace(fWorkspace, fDescriptors, this);

   // cuDNN only supports the reset-gate-after variant
   if (Architecture_t::IsCudnn())
      fResetGateAfter = true;
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::ResetGate(const Matrix_t &input, Matrix_t &dr)
-> void
{
   /*! Computes reset gate values according to the equation:
    *  reset = act(W_input . input + W_state . state + bias)
    *  activation function: sigmoid. */
   const DNN::EActivationFunction fRst = this->GetActivationFunctionF1();
   Matrix_t tmpState(fResetValue.GetNrows(), fResetValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsResetGateState);
   Architecture_t::MultiplyTranspose(fResetValue, input, fWeightsResetGate);
   Architecture_t::ScaleAdd(fResetValue, tmpState);
   Architecture_t::AddRowWise(fResetValue, fResetGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(dr, fRst, fResetValue);
   DNN::evaluateMatrix<Architecture_t>(fResetValue, fRst);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::UpdateGate(const Matrix_t &input, Matrix_t &du)
-> void
{
   /*! Computes update gate values according to the equation:
    *  update = act(W_input . input + W_state . state + bias)
    *  activation function: sigmoid. */
   const DNN::EActivationFunction fUpd = this->GetActivationFunctionF1();
   Matrix_t tmpState(fUpdateValue.GetNrows(), fUpdateValue.GetNcols());
   Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsUpdateGateState);
   Architecture_t::MultiplyTranspose(fUpdateValue, input, fWeightsUpdateGate);
   Architecture_t::ScaleAdd(fUpdateValue, tmpState);
   Architecture_t::AddRowWise(fUpdateValue, fUpdateGateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(du, fUpd, fUpdateValue);
   DNN::evaluateMatrix<Architecture_t>(fUpdateValue, fUpd);
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::CandidateValue(const Matrix_t &input, Matrix_t &dc)
-> void
{
   /*!
        vanilla GRU:
        candidate_value = act(W_input . input + W_state . (reset * state) + bias)

        but cuDNN uses the reset_after variant, which is faster (with bias mode = input):
        apply the reset gate multiplication after the matrix multiplication,
        candidate_value = act(W_input . input + reset * (W_state . state) + bias)

        activation function = tanh.

    */

   const DNN::EActivationFunction fCan = this->GetActivationFunctionF2();
   Matrix_t tmp(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
   if (!fResetGateAfter) {
      Matrix_t tmpState(fResetValue); // note: tmpState may share the fResetValue buffer, depending on the Matrix_t copy semantics
      Architecture_t::Hadamard(tmpState, fState);
      Architecture_t::MultiplyTranspose(tmp, tmpState, fWeightsCandidateState);
   } else {
      // GRU variant used in cuDNN, slightly faster
      Architecture_t::MultiplyTranspose(tmp, fState, fWeightsCandidateState);
      Architecture_t::Hadamard(tmp, fResetValue);
   }
   Architecture_t::MultiplyTranspose(fCandidateValue, input, fWeightsCandidate);
   Architecture_t::ScaleAdd(fCandidateValue, tmp);
   Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
   DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
   DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
}
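
// Note on the two variants above: they are not mathematically equivalent,
// since in general reset * (W_state . state) != W_state . (reset * state)
// (the element-wise product does not commute with the matrix product), so
// weights trained with one variant should be used with the same variant.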

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::Forward(Tensor_t &input, bool isTraining)
-> void
{
   // for cuDNN
   if (Architecture_t::IsCudnn()) {

      // input size is stride[1] of input tensor that is B x T x inputSize
      assert(input.GetStrides()[1] == this->GetInputSize());

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Architecture_t::Rearrange(x, input);

      //const auto &weights = this->GetWeightsAt(0);
      const auto &weights = this->GetWeightsTensor();

      auto &hx = this->fState;
      auto &cx = this->fCell;
      // use same for hy and cy
      auto &hy = this->fState;
      auto &cy = this->fCell;

      auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);

      if (fReturnSequence) {
         Architecture_t::Rearrange(this->GetOutput(), y); // swap B and T from y to Output
      } else {
         // tmp is a reference to y (full cuDNN output)
         Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
         Architecture_t::Copy(this->GetOutput(), tmp);
      }

      return;
   }

   // D : input size
   // H : state size
   // T : time size
   // B : batch size

   Tensor_t arrInput ( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
   // for (size_t t = 0; t < fTimeSteps; ++t) {
   //    arrInput.emplace_back(this->GetBatchSize(), this->GetInputWidth()); // T x B x D
   // }
   Architecture_t::Rearrange(arrInput, input); // B x T x D

   Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize );
   // for (size_t t = 0; t < fTimeSteps;++t) {
   //    arrOutput.emplace_back(this->GetBatchSize(), fStateSize); // T x B x H
   // }

   if (!this->fRememberState) {
      InitState(DNN::EInitialization::kZero);
   }

   /*! Pass each gate's values to CellForward() to calculate
    *  the next hidden state. */
   for (size_t t = 0; t < fTimeSteps; ++t) {
      /* Feed forward network: value of each gate being computed at each timestep t. */
      ResetGate(arrInput[t], fDerivativesReset[t]);
      Architecture_t::Copy(this->GetResetGateTensorAt(t), fResetValue);
      UpdateGate(arrInput[t], fDerivativesUpdate[t]);
      Architecture_t::Copy(this->GetUpdateGateTensorAt(t), fUpdateValue);

      CandidateValue(arrInput[t], fDerivativesCandidate[t]);
      Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);


      CellForward(fUpdateValue, fCandidateValue);

      // Architecture_t::PrintTensor(Tensor_t(fState), "state output");

      Matrix_t arrOutputMt = arrOutput[t];
      Architecture_t::Copy(arrOutputMt, fState);
   }

   if (fReturnSequence)
      Architecture_t::Rearrange(this->GetOutput(), arrOutput); // B x T x D
   else {
      // get the output at the last time step
      Tensor_t tmp = arrOutput.At(fTimeSteps - 1); // take last time step
      // shape of tmp is for CPU (column-wise) B x D, need to reshape it to B x D x 1
      // and transpose it to 1 x D x B (this is how the output is expected in column-major format)
      tmp = tmp.Reshape({tmp.GetShape()[0], tmp.GetShape()[1], 1});
      assert(tmp.GetSize() == this->GetOutput().GetSize());
      assert(tmp.GetShape()[0] == this->GetOutput().GetShape()[2]); // B is last dim in output and first in tmp
      Architecture_t::Rearrange(this->GetOutput(), tmp);
      // keep array output
      fY = arrOutput;
   }
}
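
// Shape summary for the non-cuDNN path above: the input arrives as
// B x T x D and is rearranged to T x B x D for the time loop; each
// CellForward() step updates fState (B x H), and the output is either the
// full sequence (fReturnSequence) or only the last time step, rearranged
// into the expected column-major output layout.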

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::CellForward(Matrix_t &updateGateValues, Matrix_t &candidateValues)
-> void
{
   Architecture_t::Hadamard(fState, updateGateValues);

   // this will reuse the content of updateGateValues
   Matrix_t tmp(updateGateValues); // B x H
   for (size_t j = 0; j < (size_t) tmp.GetNcols(); j++) {
      for (size_t i = 0; i < (size_t) tmp.GetNrows(); i++) {
         tmp(i,j) = 1 - tmp(i,j);
      }
   }

   // Update state
   Architecture_t::Hadamard(candidateValues, tmp);
   Architecture_t::ScaleAdd(fState, candidateValues);
}
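
// Net effect of CellForward() (all operations element-wise):
//    state = update * state_prev + (1 - update) * candidate
// i.e. the update gate interpolates between the previous hidden state and
// the new candidate value.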

//____________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::Backward(Tensor_t &gradients_backward,           // B x T x D
                                                      const Tensor_t &activations_backward)   // B x T x D
-> void
{
   // BACKWARD for cuDNN
   if (Architecture_t::IsCudnn()) {

      Tensor_t &x = this->fX;
      Tensor_t &y = this->fY;
      Tensor_t &dx = this->fDx;
      Tensor_t &dy = this->fDy;

      // input size is stride[1] of input tensor that is B x T x inputSize
      assert(activations_backward.GetStrides()[1] == this->GetInputSize());


      Architecture_t::Rearrange(x, activations_backward);

      if (!fReturnSequence) {

         Architecture_t::InitializeZero(dy);

         // Tensor_t tmp1 = y.At(y.GetShape()[0] - 1).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
         Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});

         // Architecture_t::Copy(tmp1, this->GetOutput());
         Architecture_t::Copy(tmp2, this->GetActivationGradients());
      } else {
         Architecture_t::Rearrange(y, this->GetOutput());
         Architecture_t::Rearrange(dy, this->GetActivationGradients());
      }

      // Architecture_t::PrintTensor(this->GetOutput(), "output before bwd");

      // for cuDNN, Matrix_t and Tensor_t are the same type
      const auto &weights = this->GetWeightsTensor();
      auto &weightGradients = this->GetWeightGradientsTensor();

      // note that cudnnRNNBackwardWeights accumulates the weight gradients,
      // so we need to initialize the tensor to zero every time
      Architecture_t::InitializeZero(weightGradients);

      // hx is fState
      auto &hx = this->GetState();
      auto &cx = this->GetCell();
      // use same for hy and cy
      auto &dhy = hx;
      auto &dcy = cx;
      auto &dhx = hx;
      auto &dcx = cx;

      auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
      auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);

      Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);

      // Architecture_t::PrintTensor(this->GetOutput(), "output after bwd");

      if (gradients_backward.GetSize() != 0)
         Architecture_t::Rearrange(gradients_backward, dx);

      return;
   }

   // gradients_backward are the activation gradients of the previous layer (the input layer).
   // Currently, gradients_backward refers to the input (x) and not to the state.
   // For the state we use:
   Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero); // B x H

   // if dummy is false, gradients_backward will be written back to the matrix
   bool dummy = false;
   if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
      dummy = true;
   }

   Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());


   //Architecture_t::Rearrange(arr_gradients_backward, gradients_backward); // B x T x D
   // activations_backward is the input.
   Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());

   Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D

   /*! For backpropagation we need the outputs computed during forward
    *  propagation; they are kept in the arr_output tensor. */
   Tensor_t arr_output ( fTimeSteps, this->GetBatchSize(), fStateSize);

   Matrix_t initState(this->GetBatchSize(), fStateSize); // B x H
   DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero); // B x H

   // This will take the partial derivative of state[t] w.r.t state[t-1]
   Tensor_t arr_actgradients ( fTimeSteps, this->GetBatchSize(), fStateSize);

   if (fReturnSequence) {
      Architecture_t::Rearrange(arr_output, this->GetOutput());
      Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
   } else {
      arr_output = fY;
      Architecture_t::InitializeZero(arr_actgradients);
      // need to reshape to pad a time dimension = 1 (note: column-major tensors here)
      Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape({this->GetBatchSize(), fStateSize, 1});
      assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
      assert(tmp_grad.GetShape()[0] ==
             this->GetActivationGradients().GetShape()[2]); // B in tmp is [0] and [2] in input act. gradients

      Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
   }

   /*! There are in total 6 weight matrices and 3 bias vectors.
    *  Re-initialize their gradients with zero so they do not contain garbage values. */

   // Reset Gate.
   fWeightsResetGradients.Zero();
   fWeightsResetStateGradients.Zero();
   fResetBiasGradients.Zero();

   // Update Gate.
   fWeightsUpdateGradients.Zero();
   fWeightsUpdateStateGradients.Zero();
   fUpdateBiasGradients.Zero();

   // Candidate Gate.
   fWeightsCandidateGradients.Zero();
   fWeightsCandidateStateGradients.Zero();
   fCandidateBiasGradients.Zero();


   for (size_t t = fTimeSteps; t > 0; t--) {
      // Accumulate the gradients obtained at each timestep during the backward pass.
      Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
      if (t > 1) {
         const Matrix_t &prevStateActivations = arr_output[t-2];
         Matrix_t dx = arr_gradients_backward[t-1];
         // Compute the gradients of each gate at this timestep.
         CellBackward(state_gradients_backward, prevStateActivations,
                      this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
                      this->GetCandidateGateTensorAt(t-1),
                      arr_activations_backward[t-1], dx,
                      fDerivativesReset[t-1], fDerivativesUpdate[t-1],
                      fDerivativesCandidate[t-1]);
      } else {
         const Matrix_t &prevStateActivations = initState;
         Matrix_t dx = arr_gradients_backward[t-1];
         CellBackward(state_gradients_backward, prevStateActivations,
                      this->GetResetGateTensorAt(t-1), this->GetUpdateGateTensorAt(t-1),
                      this->GetCandidateGateTensorAt(t-1),
                      arr_activations_backward[t-1], dx,
                      fDerivativesReset[t-1], fDerivativesUpdate[t-1],
                      fDerivativesCandidate[t-1]);
      }
   }

   if (!dummy) {
      Architecture_t::Rearrange(gradients_backward, arr_gradients_backward);
   }

}
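
// Sketch of the backpropagation-through-time recursion implemented above:
// at each step t the activation gradient for that step is added to the
// running state gradient, and CellBackward() then distributes it to the
// weight and bias gradients and propagates it to step t-1 through
// state_gradients_backward.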


//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::CellBackward(Matrix_t & state_gradients_backward,
                                                          const Matrix_t & precStateActivations,
                                                          const Matrix_t & reset_gate, const Matrix_t & update_gate,
                                                          const Matrix_t & candidate_gate,
                                                          const Matrix_t & input, Matrix_t & input_gradient,
                                                          Matrix_t &dr, Matrix_t &du, Matrix_t &dc)
-> Matrix_t &
{
   /*! Call GRULayerBackward() here, passing the gate values and
    *  derivatives obtained during the forward pass. */
   return Architecture_t::GRULayerBackward(state_gradients_backward,
                                           fWeightsResetGradients, fWeightsUpdateGradients, fWeightsCandidateGradients,
                                           fWeightsResetStateGradients, fWeightsUpdateStateGradients,
                                           fWeightsCandidateStateGradients, fResetBiasGradients, fUpdateBiasGradients,
                                           fCandidateBiasGradients, dr, du, dc,
                                           precStateActivations,
                                           reset_gate, update_gate, candidate_gate,
                                           fWeightsResetGate, fWeightsUpdateGate, fWeightsCandidate,
                                           fWeightsResetGateState, fWeightsUpdateGateState, fWeightsCandidateState,
                                           input, input_gradient, fResetGateAfter);
}


//______________________________________________________________________________
template <typename Architecture_t>
auto TBasicGRULayer<Architecture_t>::InitState(DNN::EInitialization /* m */)
-> void
{
   DNN::initialize<Architecture_t>(this->GetState(), DNN::EInitialization::kZero);
}

//______________________________________________________________________________
template<typename Architecture_t>
auto TBasicGRULayer<Architecture_t>::Print() const
-> void
{
   std::cout << " GRU Layer: \t ";
   std::cout << " (NInput = " << this->GetInputSize();  // input size
   std::cout << ", NState = " << this->GetStateSize();  // hidden state size
   std::cout << ", NTime  = " << this->GetTimeSteps() << " )";  // time size
   std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput()[0].GetNrows() << " , " << this->GetOutput()[0].GetNcols() << " )\n";
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::AddWeightsXMLTo(void *parent)
-> void
{
   auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "GRULayer");

   // Write the layer info: stateSize, inputSize, timeSteps, rememberState, returnSequence, resetGateAfter
   gTools().xmlengine().NewAttr(layerxml, nullptr, "StateSize", gTools().StringFromInt(this->GetStateSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "InputSize", gTools().StringFromInt(this->GetInputSize()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));
   gTools().xmlengine().NewAttr(layerxml, nullptr, "ResetGateAfter", gTools().StringFromInt(this->fResetGateAfter));

   // write weight and bias matrices
   this->WriteMatrixToXML(layerxml, "ResetWeights", this->GetWeightsAt(0));
   this->WriteMatrixToXML(layerxml, "ResetStateWeights", this->GetWeightsAt(1));
   this->WriteMatrixToXML(layerxml, "ResetBiases", this->GetBiasesAt(0));
   this->WriteMatrixToXML(layerxml, "UpdateWeights", this->GetWeightsAt(2));
   this->WriteMatrixToXML(layerxml, "UpdateStateWeights", this->GetWeightsAt(3));
   this->WriteMatrixToXML(layerxml, "UpdateBiases", this->GetBiasesAt(1));
   this->WriteMatrixToXML(layerxml, "CandidateWeights", this->GetWeightsAt(4));
   this->WriteMatrixToXML(layerxml, "CandidateStateWeights", this->GetWeightsAt(5));
   this->WriteMatrixToXML(layerxml, "CandidateBiases", this->GetBiasesAt(2));
}

//______________________________________________________________________________
template <typename Architecture_t>
auto inline TBasicGRULayer<Architecture_t>::ReadWeightsFromXML(void *parent)
-> void
{
   // Read weights and biases
   this->ReadMatrixXML(parent, "ResetWeights", this->GetWeightsAt(0));
   this->ReadMatrixXML(parent, "ResetStateWeights", this->GetWeightsAt(1));
   this->ReadMatrixXML(parent, "ResetBiases", this->GetBiasesAt(0));
   this->ReadMatrixXML(parent, "UpdateWeights", this->GetWeightsAt(2));
   this->ReadMatrixXML(parent, "UpdateStateWeights", this->GetWeightsAt(3));
   this->ReadMatrixXML(parent, "UpdateBiases", this->GetBiasesAt(1));
   this->ReadMatrixXML(parent, "CandidateWeights", this->GetWeightsAt(4));
   this->ReadMatrixXML(parent, "CandidateStateWeights", this->GetWeightsAt(5));
   this->ReadMatrixXML(parent, "CandidateBiases", this->GetBiasesAt(2));
}

} // namespace RNN
} // namespace DNN
} // namespace TMVA

#endif // TMVA_DNN_GRU_LAYER