Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:10:54

0001 // @(#)root/tmva/tmva/dnn/lstm:$Id$
0002 // Author: Surya S Dwivedi 27/05/19
0003 
0004 /**********************************************************************************
0005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
0006  * Package: TMVA                                                                  *
0007  * Class : BasicLSTMLayer                                                         *
0008  *                                                                                *
0009  * Description:                                                                   *
0010  *       NeuralNetwork                                                            *
0011  *                                                                                *
0012  * Authors (alphabetical):                                                        *
0013  *       Surya S Dwivedi  <surya2191997@gmail.com> - IIT Kharagpur, India         *
0014  *                                                                                *
0015  * Copyright (c) 2005-2019:                                                       *
0016  * All rights reserved.                                                           *
0017  *       CERN, Switzerland                                                        *
0018  *                                                                                *
0019  * For the licensing terms see $ROOTSYS/LICENSE.                                  *
0020  * For the list of contributors see $ROOTSYS/README/CREDITS.                      *
0021  **********************************************************************************/
0022 
0023 //#pragma once
0024 
0025 //////////////////////////////////////////////////////////////////////
0026 // This class implements the LSTM layer. LSTM is a variant of vanilla
0027 // RNN which is capable of learning long range dependencies.
0028 //////////////////////////////////////////////////////////////////////
0029 
0030 #ifndef TMVA_DNN_LSTM_LAYER
0031 #define TMVA_DNN_LSTM_LAYER
0032 
0033 #include <cmath>
0034 #include <iostream>
0035 #include <vector>
0036 
0037 #include "TMatrix.h"
0038 #include "TMVA/DNN/Functions.h"
0039 
0040 namespace TMVA
0041 {
0042 namespace DNN
0043 {
0044 namespace RNN
0045 {
0046 
0047 //______________________________________________________________________________
0048 //
0049 // Basic LSTM Layer
0050 //______________________________________________________________________________
0051 
0052 /** \class BasicLSTMLayer
0053       Generic implementation
0054 */
0055 template<typename Architecture_t>
0056       class TBasicLSTMLayer : public VGeneralLayer<Architecture_t>
0057 {
0058 
0059 public:
0060 
0061    using Matrix_t = typename Architecture_t::Matrix_t;
0062    using Scalar_t = typename Architecture_t::Scalar_t;
0063    using Tensor_t = typename Architecture_t::Tensor_t;
0064 
0065    using LayerDescriptor_t = typename Architecture_t::RecurrentDescriptor_t;
0066    using WeightsDescriptor_t = typename Architecture_t::FilterDescriptor_t;
0067    using TensorDescriptor_t = typename Architecture_t::TensorDescriptor_t;
0068    using HelperDescriptor_t = typename Architecture_t::DropoutDescriptor_t;
0069 
0070    using RNNWorkspace_t = typename Architecture_t::RNNWorkspace_t;
0071    using RNNDescriptors_t = typename Architecture_t::RNNDescriptors_t;
0072 
0073 private:
0074 
0075    size_t fStateSize;                           ///< Hidden state size for LSTM
0076    size_t fCellSize;                            ///< Cell state size of LSTM
0077    size_t fTimeSteps;                           ///< Timesteps for LSTM
0078 
0079    bool fRememberState;                         ///< Remember state in next pass
0080    bool fReturnSequence = false;                ///< Return in output full sequence or just last element
0081 
0082    DNN::EActivationFunction fF1;                ///< Activation function: sigmoid
0083    DNN::EActivationFunction fF2;                ///< Activation function: tanh
0084 
0085    Matrix_t fInputValue;                        ///< Computed input gate values
0086    Matrix_t fCandidateValue;                    ///< Computed candidate values
0087    Matrix_t fForgetValue;                       ///< Computed forget gate values
0088    Matrix_t fOutputValue;                       ///< Computed output gate values
0089    Matrix_t fState;                             ///< Hidden state of LSTM
0090    Matrix_t fCell;                              ///< Cell state of LSTM
0091 
0092    Matrix_t &fWeightsInputGate;                 ///< Input Gate weights for input, fWeights[0]
0093    Matrix_t &fWeightsInputGateState;            ///< Input Gate weights for prev state, fWeights[1]
0094    Matrix_t &fInputGateBias;                    ///< Input Gate bias
0095 
0096    Matrix_t &fWeightsForgetGate;                ///< Forget Gate weights for input, fWeights[2]
0097    Matrix_t &fWeightsForgetGateState;           ///< Forget Gate weights for prev state, fWeights[3]
0098    Matrix_t &fForgetGateBias;                   ///< Forget Gate bias
0099 
0100    Matrix_t &fWeightsCandidate;                 ///< Candidate Gate weights for input, fWeights[4]
0101    Matrix_t &fWeightsCandidateState;            ///< Candidate Gate weights for prev state, fWeights[5]
0102    Matrix_t &fCandidateBias;                    ///< Candidate Gate bias
0103 
0104    Matrix_t &fWeightsOutputGate;                ///< Output Gate weights for input, fWeights[6]
0105    Matrix_t &fWeightsOutputGateState;           ///< Output Gate weights for prev state, fWeights[7]
0106    Matrix_t &fOutputGateBias;                   ///< Output Gate bias
0107 
0108    std::vector<Matrix_t> input_gate_value;      ///< input gate value for every time step
0109    std::vector<Matrix_t> forget_gate_value;     ///< forget gate value for every time step
0110    std::vector<Matrix_t> candidate_gate_value;  ///< candidate gate value for every time step
0111    std::vector<Matrix_t> output_gate_value;     ///< output gate value for every time step
0112    std::vector<Matrix_t> cell_value;            ///< cell value for every time step
0113    std::vector<Matrix_t> fDerivativesInput;     ///< First fDerivatives of the activations input gate
0114    std::vector<Matrix_t> fDerivativesForget;    ///< First fDerivatives of the activations forget gate
0115    std::vector<Matrix_t> fDerivativesCandidate; ///< First fDerivatives of the activations candidate gate
0116    std::vector<Matrix_t> fDerivativesOutput;    ///< First fDerivatives of the activations output gate
0117 
0118    Matrix_t &fWeightsInputGradients;            ///< Gradients w.r.t the input gate - input weights
0119    Matrix_t &fWeightsInputStateGradients;       ///< Gradients w.r.t the input gate - hidden state weights
0120    Matrix_t &fInputBiasGradients;               ///< Gradients w.r.t the input gate - bias weights
0121    Matrix_t &fWeightsForgetGradients;           ///< Gradients w.r.t the forget gate - input weights
0122    Matrix_t &fWeightsForgetStateGradients;      ///< Gradients w.r.t the forget gate - hidden state weights
0123    Matrix_t &fForgetBiasGradients;              ///< Gradients w.r.t the forget gate - bias weights
0124    Matrix_t &fWeightsCandidateGradients;        ///< Gradients w.r.t the candidate gate - input weights
0125    Matrix_t &fWeightsCandidateStateGradients;   ///< Gradients w.r.t the candidate gate - hidden state weights
0126    Matrix_t &fCandidateBiasGradients;           ///< Gradients w.r.t the candidate gate - bias weights
0127    Matrix_t &fWeightsOutputGradients;           ///< Gradients w.r.t the output gate - input weights
0128    Matrix_t &fWeightsOutputStateGradients;      ///< Gradients w.r.t the output gate - hidden state weights
0129    Matrix_t &fOutputBiasGradients;              ///< Gradients w.r.t the output gate - bias weights
0130 
0131    // Tensor representing all weights (used by cuDNN)
0132    Tensor_t fWeightsTensor;                     ///< Tensor for all weights
0133    Tensor_t fWeightGradientsTensor;             ///< Tensor for all weight gradients
0134 
0135    // tensors used internally for the forward and backward pass
0136    Tensor_t fX;  ///<  cached input tensor as T x B x I
0137    Tensor_t fY;  ///<  cached output tensor as T x B x S
0138    Tensor_t fDx; ///< cached   gradient on the input (output of backward)   as T x B x I
0139    Tensor_t fDy; ///< cached  activation gradient (input of backward)   as T x B x S
0140 
0141    TDescriptors *fDescriptors = nullptr; ///< Keeps all the RNN descriptors
0142    TWorkspace *fWorkspace = nullptr;     // workspace needed for GPU computation (CudNN)
0143 
0144 public:
0145 
0146    /*! Constructor */
0147    TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps, bool rememberState = false,
0148                    bool returnSequence = false,
0149                    DNN::EActivationFunction f1 = DNN::EActivationFunction::kSigmoid,
0150                    DNN::EActivationFunction f2 = DNN::EActivationFunction::kTanh, bool training = true,
0151                    DNN::EInitialization fA = DNN::EInitialization::kZero);
0152 
0153    /*! Copy Constructor */
0154    TBasicLSTMLayer(const TBasicLSTMLayer &);
0155 
0156    /*! Initialize the weights according to the given initialization
0157     **  method. */
0158    virtual void Initialize();
0159 
0160    /*! Initialize the hidden state and cell state method. */
0161    void InitState(DNN::EInitialization m = DNN::EInitialization::kZero);
0162 
0163    /*! Computes the next hidden state
0164     *  and next cell state with given input matrix. */
0165    void Forward(Tensor_t &input, bool isTraining = true);
0166 
0167    /*! Forward for a single cell (time unit) */
0168    void CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
0169                   const Matrix_t &candidateValues, const Matrix_t &outputGateValues);
0170 
0171    /*! Backpropagates the error. Must only be called directly at the corresponding
0172     *  call to Forward(...). */
0173    void Backward(Tensor_t &gradients_backward,
0174                  const Tensor_t &activations_backward);
0175 
0176    /* Updates weights and biases, given the learning rate */
0177    void Update(const Scalar_t learningRate);
0178 
0179    /*! Backward for a single time unit
0180     *  a the corresponding call to Forward(...). */
0181    Matrix_t & CellBackward(Matrix_t & state_gradients_backward,
0182                            Matrix_t & cell_gradients_backward,
0183                            const Matrix_t & precStateActivations, const Matrix_t & precCellActivations,
0184                            const Matrix_t & input_gate, const Matrix_t & forget_gate,
0185                            const Matrix_t & candidate_gate, const Matrix_t & output_gate,
0186                            const Matrix_t & input, Matrix_t & input_gradient,
0187                            Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout, size_t t);
0188 
0189    /*! Decides the values we'll update (NN with Sigmoid) */
0190    void InputGate(const Matrix_t &input, Matrix_t &di);
0191 
0192    /*! Forgets the past values (NN with Sigmoid) */
0193    void ForgetGate(const Matrix_t &input, Matrix_t &df);
0194 
0195    /*! Decides the new candidate values (NN with Tanh) */
0196    void CandidateValue(const Matrix_t &input, Matrix_t &dc);
0197 
0198    /*! Computes output values (NN with Sigmoid) */
0199    void OutputGate(const Matrix_t &input, Matrix_t &dout);
0200 
0201    /*! Prints the info about the layer */
0202    void Print() const;
0203 
0204    /*! Writes the information and the weights about the layer in an XML node. */
0205    void AddWeightsXMLTo(void *parent);
0206 
0207    /*! Read the information and the weights about the layer from XML node. */
0208    void ReadWeightsFromXML(void *parent);
0209 
0210    /*! Getters */
0211    size_t GetInputSize()               const { return this->GetInputWidth(); }
0212    size_t GetTimeSteps()               const { return fTimeSteps; }
0213    size_t GetStateSize()               const { return fStateSize; }
0214    size_t GetCellSize()                const { return fCellSize; }
0215 
0216    inline bool DoesRememberState()       const { return fRememberState; }
0217    inline bool DoesReturnSequence()      const { return fReturnSequence; }
0218 
0219    inline DNN::EActivationFunction     GetActivationFunctionF1()        const { return fF1; }
0220    inline DNN::EActivationFunction     GetActivationFunctionF2()        const { return fF2; }
0221 
0222    const Matrix_t                    & GetInputGateValue()                const { return fInputValue; }
0223    Matrix_t                          & GetInputGateValue()                      { return fInputValue; }
0224    const Matrix_t                    & GetCandidateValue()                const { return fCandidateValue; }
0225    Matrix_t                          & GetCandidateValue()                      { return fCandidateValue; }
0226    const Matrix_t                    & GetForgetGateValue()               const { return fForgetValue; }
0227    Matrix_t                          & GetForgetGateValue()                     { return fForgetValue; }
0228    const Matrix_t                    & GetOutputGateValue()               const { return fOutputValue; }
0229    Matrix_t                          & GetOutputGateValue()                     { return fOutputValue; }
0230 
0231    const Matrix_t                    & GetState()                   const { return fState; }
0232    Matrix_t                          & GetState()                         { return fState; }
0233    const Matrix_t                    & GetCell()                    const { return fCell; }
0234    Matrix_t                          & GetCell()                          { return fCell; }
0235 
0236    const Matrix_t                    & GetWeightsInputGate()              const { return fWeightsInputGate; }
0237    Matrix_t                          & GetWeightsInputGate()                    { return fWeightsInputGate; }
0238    const Matrix_t                    & GetWeightsCandidate()              const { return fWeightsCandidate; }
0239    Matrix_t                          & GetWeightsCandidate()                    { return fWeightsCandidate; }
0240    const Matrix_t                    & GetWeightsForgetGate()             const { return fWeightsForgetGate; }
0241    Matrix_t                          & GetWeightsForgetGate()                   { return fWeightsForgetGate; }
0242    const Matrix_t                    & GetWeightsOutputGate()             const { return fWeightsOutputGate; }
0243    Matrix_t                          & GetWeightsOutputGate()                   { return fWeightsOutputGate; }
0244    const Matrix_t                    & GetWeightsInputGateState()         const { return fWeightsInputGateState; }
0245    Matrix_t                          & GetWeightsInputGateState()               { return fWeightsInputGateState; }
0246    const Matrix_t                    & GetWeightsForgetGateState()        const { return fWeightsForgetGateState; }
0247    Matrix_t                          & GetWeightsForgetGateState()              { return fWeightsForgetGateState; }
0248    const Matrix_t                    & GetWeightsCandidateState()         const { return fWeightsCandidateState; }
0249    Matrix_t                          & GetWeightsCandidateState()               { return fWeightsCandidateState; }
0250    const Matrix_t                    & GetWeightsOutputGateState()        const { return fWeightsOutputGateState; }
0251    Matrix_t                          & GetWeightsOutputGateState()              { return fWeightsOutputGateState; }
0252 
0253    const std::vector<Matrix_t>       & GetDerivativesInput()              const { return fDerivativesInput; }
0254    std::vector<Matrix_t>             & GetDerivativesInput()                    { return fDerivativesInput; }
0255    const Matrix_t                    & GetInputDerivativesAt(size_t i)    const { return fDerivativesInput[i]; }
0256    Matrix_t                          & GetInputDerivativesAt(size_t i)           { return fDerivativesInput[i]; }
0257    const std::vector<Matrix_t>       & GetDerivativesForget()              const { return fDerivativesForget; }
0258    std::vector<Matrix_t>             & GetDerivativesForget()                    { return fDerivativesForget; }
0259    const Matrix_t                    & GetForgetDerivativesAt(size_t i)    const { return fDerivativesForget[i]; }
0260    Matrix_t                          & GetForgetDerivativesAt(size_t i)          { return fDerivativesForget[i]; }
0261    const std::vector<Matrix_t>       & GetDerivativesCandidate()           const { return fDerivativesCandidate; }
0262    std::vector<Matrix_t>             & GetDerivativesCandidate()                 { return fDerivativesCandidate; }
0263    const Matrix_t                    & GetCandidateDerivativesAt(size_t i) const { return fDerivativesCandidate[i]; }
0264    Matrix_t                          & GetCandidateDerivativesAt(size_t i)       { return fDerivativesCandidate[i]; }
0265    const std::vector<Matrix_t>       & GetDerivativesOutput()              const { return fDerivativesOutput; }
0266    std::vector<Matrix_t>             & GetDerivativesOutput()                    { return fDerivativesOutput; }
0267    const Matrix_t                    & GetOutputDerivativesAt(size_t i)    const { return fDerivativesOutput[i]; }
0268    Matrix_t                          & GetOutputDerivativesAt(size_t i)          { return fDerivativesOutput[i]; }
0269 
0270    const std::vector<Matrix_t>       & GetInputGateTensor()              const { return input_gate_value; }
0271    std::vector<Matrix_t>             & GetInputGateTensor()                    { return input_gate_value; }
0272    const Matrix_t                    & GetInputGateTensorAt(size_t i)    const { return input_gate_value[i]; }
0273    Matrix_t                          & GetInputGateTensorAt(size_t i)           { return input_gate_value[i]; }
0274    const std::vector<Matrix_t>       & GetForgetGateTensor()              const { return forget_gate_value; }
0275    std::vector<Matrix_t>             & GetForgetGateTensor()                    { return forget_gate_value; }
0276    const Matrix_t                    & GetForgetGateTensorAt(size_t i)    const { return forget_gate_value[i]; }
0277    Matrix_t                          & GetForgetGateTensorAt(size_t i)          { return forget_gate_value[i]; }
0278    const std::vector<Matrix_t>       & GetCandidateGateTensor()           const { return candidate_gate_value; }
0279    std::vector<Matrix_t>             & GetCandidateGateTensor()                 { return candidate_gate_value; }
0280    const Matrix_t                    & GetCandidateGateTensorAt(size_t i) const { return candidate_gate_value[i]; }
0281    Matrix_t                          & GetCandidateGateTensorAt(size_t i)       { return candidate_gate_value[i]; }
0282    const std::vector<Matrix_t>       & GetOutputGateTensor()              const { return output_gate_value; }
0283    std::vector<Matrix_t>             & GetOutputGateTensor()                    { return output_gate_value; }
0284    const Matrix_t                    & GetOutputGateTensorAt(size_t i)    const { return output_gate_value[i]; }
0285    Matrix_t                          & GetOutputGateTensorAt(size_t i)          { return output_gate_value[i]; }
0286    const std::vector<Matrix_t>       & GetCellTensor()                    const { return cell_value; }
0287    std::vector<Matrix_t>             & GetCellTensor()                          { return cell_value; }
0288    const Matrix_t                    & GetCellTensorAt(size_t i)          const { return cell_value[i]; }
0289    Matrix_t                          & GetCellTensorAt(size_t i)                { return cell_value[i]; }
0290 
0291    const Matrix_t                   & GetInputGateBias()         const { return fInputGateBias; }
0292    Matrix_t                         & GetInputGateBias()               { return fInputGateBias; }
0293    const Matrix_t                   & GetForgetGateBias()        const { return fForgetGateBias; }
0294    Matrix_t                         & GetForgetGateBias()              { return fForgetGateBias; }
0295    const Matrix_t                   & GetCandidateBias()         const { return fCandidateBias; }
0296    Matrix_t                         & GetCandidateBias()               { return fCandidateBias; }
0297    const Matrix_t                   & GetOutputGateBias()        const { return fOutputGateBias; }
0298    Matrix_t                         & GetOutputGateBias()              { return fOutputGateBias; }
0299    const Matrix_t                   & GetWeightsInputGradients()        const { return fWeightsInputGradients; }
0300    Matrix_t                         & GetWeightsInputGradients()              { return fWeightsInputGradients; }
0301    const Matrix_t                   & GetWeightsInputStateGradients()   const { return fWeightsInputStateGradients; }
0302    Matrix_t                         & GetWeightsInputStateGradients()         { return fWeightsInputStateGradients; }
0303    const Matrix_t                   & GetInputBiasGradients()           const { return fInputBiasGradients; }
0304    Matrix_t                         & GetInputBiasGradients()                 { return fInputBiasGradients; }
0305    const Matrix_t                   & GetWeightsForgetGradients()      const { return fWeightsForgetGradients; }
0306    Matrix_t                         & GetWeightsForgetGradients()            { return fWeightsForgetGradients; }
0307    const Matrix_t                   & GetWeigthsForgetStateGradients() const { return fWeightsForgetStateGradients; }
0308    Matrix_t                         & GetWeightsForgetStateGradients()       { return fWeightsForgetStateGradients; }
0309    const Matrix_t                   & GetForgetBiasGradients()         const { return fForgetBiasGradients; }
0310    Matrix_t                         & GetForgetBiasGradients()               { return fForgetBiasGradients; }
0311    const Matrix_t                   & GetWeightsCandidateGradients()      const { return fWeightsCandidateGradients; }
0312    Matrix_t                         & GetWeightsCandidateGradients()            { return fWeightsCandidateGradients; }
0313    const Matrix_t                   & GetWeightsCandidateStateGradients() const { return fWeightsCandidateStateGradients; }
0314    Matrix_t                         & GetWeightsCandidateStateGradients()       { return fWeightsCandidateStateGradients; }
0315    const Matrix_t                   & GetCandidateBiasGradients()         const { return fCandidateBiasGradients; }
0316    Matrix_t                         & GetCandidateBiasGradients()               { return fCandidateBiasGradients; }
0317    const Matrix_t                   & GetWeightsOutputGradients()        const { return fWeightsOutputGradients; }
0318    Matrix_t                         & GetWeightsOutputGradients()              { return fWeightsOutputGradients; }
0319    const Matrix_t                   & GetWeightsOutputStateGradients()   const { return fWeightsOutputStateGradients; }
0320    Matrix_t                         & GetWeightsOutputStateGradients()         { return fWeightsOutputStateGradients; }
0321    const Matrix_t                   & GetOutputBiasGradients()           const { return fOutputBiasGradients; }
0322    Matrix_t                         & GetOutputBiasGradients()                 { return fOutputBiasGradients; }
0323 
0324    Tensor_t &GetWeightsTensor() { return fWeightsTensor; }
0325    const Tensor_t &GetWeightsTensor() const { return fWeightsTensor; }
0326    Tensor_t &GetWeightGradientsTensor() { return fWeightGradientsTensor; }
0327    const Tensor_t &GetWeightGradientsTensor() const { return fWeightGradientsTensor; }
0328 
0329    Tensor_t &GetX() { return fX; }
0330    Tensor_t &GetY() { return fY; }
0331    Tensor_t &GetDX() { return fDx; }
0332    Tensor_t &GetDY() { return fDy; }
0333 };
0334 
0335 //______________________________________________________________________________
0336 //
0337 // Basic LSTM-Layer Implementation
0338 //______________________________________________________________________________
0339 
0340 template <typename Architecture_t>
0341 TBasicLSTMLayer<Architecture_t>::TBasicLSTMLayer(size_t batchSize, size_t stateSize, size_t inputSize, size_t timeSteps,
0342                                                  bool rememberState, bool returnSequence, DNN::EActivationFunction f1,
0343                                                  DNN::EActivationFunction f2, bool /* training */,
0344                                                  DNN::EInitialization fA)
0345    : VGeneralLayer<Architecture_t>(
0346         batchSize, 1, timeSteps, inputSize, 1, (returnSequence) ? timeSteps : 1, stateSize, 8,
0347         {stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize, stateSize},
0348         {inputSize, inputSize, inputSize, inputSize, stateSize, stateSize, stateSize, stateSize}, 4,
0349         {stateSize, stateSize, stateSize, stateSize}, {1, 1, 1, 1}, batchSize, (returnSequence) ? timeSteps : 1,
0350         stateSize, fA),
0351      fStateSize(stateSize), fCellSize(stateSize), fTimeSteps(timeSteps), fRememberState(rememberState),
0352      fReturnSequence(returnSequence), fF1(f1), fF2(f2), fInputValue(batchSize, stateSize),
0353      fCandidateValue(batchSize, stateSize), fForgetValue(batchSize, stateSize), fOutputValue(batchSize, stateSize),
0354      fState(batchSize, stateSize), fCell(batchSize, stateSize), fWeightsInputGate(this->GetWeightsAt(0)),
0355      fWeightsInputGateState(this->GetWeightsAt(4)), fInputGateBias(this->GetBiasesAt(0)),
0356      fWeightsForgetGate(this->GetWeightsAt(1)), fWeightsForgetGateState(this->GetWeightsAt(5)),
0357      fForgetGateBias(this->GetBiasesAt(1)), fWeightsCandidate(this->GetWeightsAt(2)),
0358      fWeightsCandidateState(this->GetWeightsAt(6)), fCandidateBias(this->GetBiasesAt(2)),
0359      fWeightsOutputGate(this->GetWeightsAt(3)), fWeightsOutputGateState(this->GetWeightsAt(7)),
0360      fOutputGateBias(this->GetBiasesAt(3)), fWeightsInputGradients(this->GetWeightGradientsAt(0)),
0361      fWeightsInputStateGradients(this->GetWeightGradientsAt(4)), fInputBiasGradients(this->GetBiasGradientsAt(0)),
0362      fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
0363      fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)), fForgetBiasGradients(this->GetBiasGradientsAt(1)),
0364      fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
0365      fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
0366      fCandidateBiasGradients(this->GetBiasGradientsAt(2)), fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
0367      fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)), fOutputBiasGradients(this->GetBiasGradientsAt(3))
0368 {
0369    for (size_t i = 0; i < timeSteps; ++i) {
0370       fDerivativesInput.emplace_back(batchSize, stateSize);
0371       fDerivativesForget.emplace_back(batchSize, stateSize);
0372       fDerivativesCandidate.emplace_back(batchSize, stateSize);
0373       fDerivativesOutput.emplace_back(batchSize, stateSize);
0374       input_gate_value.emplace_back(batchSize, stateSize);
0375       forget_gate_value.emplace_back(batchSize, stateSize);
0376       candidate_gate_value.emplace_back(batchSize, stateSize);
0377       output_gate_value.emplace_back(batchSize, stateSize);
0378       cell_value.emplace_back(batchSize, stateSize);
0379    }
0380    Architecture_t::InitializeLSTMTensors(this);
0381 }
0382 
0383  //______________________________________________________________________________
0384 template <typename Architecture_t>
0385 TBasicLSTMLayer<Architecture_t>::TBasicLSTMLayer(const TBasicLSTMLayer &layer)
0386    : VGeneralLayer<Architecture_t>(layer),
0387       fStateSize(layer.fStateSize),
0388       fCellSize(layer.fCellSize),
0389       fTimeSteps(layer.fTimeSteps),
0390       fRememberState(layer.fRememberState),
0391       fReturnSequence(layer.fReturnSequence),
0392       fF1(layer.GetActivationFunctionF1()),
0393       fF2(layer.GetActivationFunctionF2()),
0394       fInputValue(layer.GetBatchSize(), layer.GetStateSize()),
0395       fCandidateValue(layer.GetBatchSize(), layer.GetStateSize()),
0396       fForgetValue(layer.GetBatchSize(), layer.GetStateSize()),
0397       fOutputValue(layer.GetBatchSize(), layer.GetStateSize()),
0398       fState(layer.GetBatchSize(), layer.GetStateSize()),
0399       fCell(layer.GetBatchSize(), layer.GetCellSize()),
0400       fWeightsInputGate(this->GetWeightsAt(0)),
0401       fWeightsInputGateState(this->GetWeightsAt(4)),
0402       fInputGateBias(this->GetBiasesAt(0)),
0403       fWeightsForgetGate(this->GetWeightsAt(1)),
0404       fWeightsForgetGateState(this->GetWeightsAt(5)),
0405       fForgetGateBias(this->GetBiasesAt(1)),
0406       fWeightsCandidate(this->GetWeightsAt(2)),
0407       fWeightsCandidateState(this->GetWeightsAt(6)),
0408       fCandidateBias(this->GetBiasesAt(2)),
0409       fWeightsOutputGate(this->GetWeightsAt(3)),
0410       fWeightsOutputGateState(this->GetWeightsAt(7)),
0411       fOutputGateBias(this->GetBiasesAt(3)),
0412       fWeightsInputGradients(this->GetWeightGradientsAt(0)),
0413       fWeightsInputStateGradients(this->GetWeightGradientsAt(4)),
0414       fInputBiasGradients(this->GetBiasGradientsAt(0)),
0415       fWeightsForgetGradients(this->GetWeightGradientsAt(1)),
0416       fWeightsForgetStateGradients(this->GetWeightGradientsAt(5)),
0417       fForgetBiasGradients(this->GetBiasGradientsAt(1)),
0418       fWeightsCandidateGradients(this->GetWeightGradientsAt(2)),
0419       fWeightsCandidateStateGradients(this->GetWeightGradientsAt(6)),
0420       fCandidateBiasGradients(this->GetBiasGradientsAt(2)),
0421       fWeightsOutputGradients(this->GetWeightGradientsAt(3)),
0422       fWeightsOutputStateGradients(this->GetWeightGradientsAt(7)),
0423       fOutputBiasGradients(this->GetBiasGradientsAt(3))
0424 {
0425    for (size_t i = 0; i < fTimeSteps; ++i) {
0426       fDerivativesInput.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0427       Architecture_t::Copy(fDerivativesInput[i], layer.GetInputDerivativesAt(i));
0428 
0429       fDerivativesForget.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0430       Architecture_t::Copy(fDerivativesForget[i], layer.GetForgetDerivativesAt(i));
0431 
0432       fDerivativesCandidate.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0433       Architecture_t::Copy(fDerivativesCandidate[i], layer.GetCandidateDerivativesAt(i));
0434 
0435       fDerivativesOutput.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0436       Architecture_t::Copy(fDerivativesOutput[i], layer.GetOutputDerivativesAt(i));
0437 
0438       input_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0439       Architecture_t::Copy(input_gate_value[i], layer.GetInputGateTensorAt(i));
0440 
0441       forget_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0442       Architecture_t::Copy(forget_gate_value[i], layer.GetForgetGateTensorAt(i));
0443 
0444       candidate_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0445       Architecture_t::Copy(candidate_gate_value[i], layer.GetCandidateGateTensorAt(i));
0446 
0447       output_gate_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0448       Architecture_t::Copy(output_gate_value[i], layer.GetOutputGateTensorAt(i));
0449 
0450       cell_value.emplace_back(layer.GetBatchSize(), layer.GetStateSize());
0451       Architecture_t::Copy(cell_value[i], layer.GetCellTensorAt(i));
0452    }
0453 
0454    // Gradient matrices not copied
0455    Architecture_t::Copy(fState, layer.GetState());
0456    Architecture_t::Copy(fCell, layer.GetCell());
0457 
0458    // Copy each gate values.
0459    Architecture_t::Copy(fInputValue, layer.GetInputGateValue());
0460    Architecture_t::Copy(fCandidateValue, layer.GetCandidateValue());
0461    Architecture_t::Copy(fForgetValue, layer.GetForgetGateValue());
0462    Architecture_t::Copy(fOutputValue, layer.GetOutputGateValue());
0463 
0464    Architecture_t::InitializeLSTMTensors(this);
0465 }
0466 
0467 //______________________________________________________________________________
0468 template <typename Architecture_t>
0469 void TBasicLSTMLayer<Architecture_t>::Initialize()
0470 {
0471    VGeneralLayer<Architecture_t>::Initialize();
0472 
0473    Architecture_t::InitializeLSTMDescriptors(fDescriptors, this);
0474    Architecture_t::InitializeLSTMWorkspace(fWorkspace, fDescriptors, this);
0475 }
0476 
0477 //______________________________________________________________________________
0478 template <typename Architecture_t>
0479 auto inline TBasicLSTMLayer<Architecture_t>::InputGate(const Matrix_t &input, Matrix_t &di)
0480 -> void
0481 {
0482    /*! Computes input gate values according to equation:
0483     *  input = act(W_input . input + W_state . state + bias)
0484     *  activation function: sigmoid. */
0485    const DNN::EActivationFunction fInp = this->GetActivationFunctionF1();
0486    Matrix_t tmpState(fInputValue.GetNrows(), fInputValue.GetNcols());
0487    Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsInputGateState);
0488    Architecture_t::MultiplyTranspose(fInputValue, input, fWeightsInputGate);
0489    Architecture_t::ScaleAdd(fInputValue, tmpState);
0490    Architecture_t::AddRowWise(fInputValue, fInputGateBias);
0491    DNN::evaluateDerivativeMatrix<Architecture_t>(di, fInp, fInputValue);
0492    DNN::evaluateMatrix<Architecture_t>(fInputValue, fInp);
0493 }
0494 
0495  //______________________________________________________________________________
0496 template <typename Architecture_t>
0497 auto inline TBasicLSTMLayer<Architecture_t>::ForgetGate(const Matrix_t &input, Matrix_t &df)
0498 -> void
0499 {
0500    /*! Computes forget gate values according to equation:
0501     *  forget = act(W_input . input + W_state . state + bias)
0502     *  activation function: sigmoid. */
0503    const DNN::EActivationFunction fFor = this->GetActivationFunctionF1();
0504    Matrix_t tmpState(fForgetValue.GetNrows(), fForgetValue.GetNcols());
0505    Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsForgetGateState);
0506    Architecture_t::MultiplyTranspose(fForgetValue, input, fWeightsForgetGate);
0507    Architecture_t::ScaleAdd(fForgetValue, tmpState);
0508    Architecture_t::AddRowWise(fForgetValue, fForgetGateBias);
0509    DNN::evaluateDerivativeMatrix<Architecture_t>(df, fFor, fForgetValue);
0510    DNN::evaluateMatrix<Architecture_t>(fForgetValue, fFor);
0511 }
0512 
0513  //______________________________________________________________________________
0514 template <typename Architecture_t>
0515 auto inline TBasicLSTMLayer<Architecture_t>::CandidateValue(const Matrix_t &input, Matrix_t &dc)
0516 -> void
0517 {
0518    /*! Candidate value will be used to scale input gate values followed by Hadamard product.
0519     *  candidate_value = act(W_input . input + W_state . state + bias)
0520     *  activation function = tanh. */
0521    const DNN::EActivationFunction fCan = this->GetActivationFunctionF2();
0522    Matrix_t tmpState(fCandidateValue.GetNrows(), fCandidateValue.GetNcols());
0523    Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsCandidateState);
0524    Architecture_t::MultiplyTranspose(fCandidateValue, input, fWeightsCandidate);
0525    Architecture_t::ScaleAdd(fCandidateValue, tmpState);
0526    Architecture_t::AddRowWise(fCandidateValue, fCandidateBias);
0527    DNN::evaluateDerivativeMatrix<Architecture_t>(dc, fCan, fCandidateValue);
0528    DNN::evaluateMatrix<Architecture_t>(fCandidateValue, fCan);
0529 }
0530 
0531  //______________________________________________________________________________
0532 template <typename Architecture_t>
0533 auto inline TBasicLSTMLayer<Architecture_t>::OutputGate(const Matrix_t &input, Matrix_t &dout)
0534 -> void
0535 {
0536    /*! Output gate values will be used to calculate next hidden state and output values.
0537     *  output = act(W_input . input + W_state . state + bias)
0538     *  activation function = sigmoid. */
0539    const DNN::EActivationFunction fOut = this->GetActivationFunctionF1();
0540    Matrix_t tmpState(fOutputValue.GetNrows(), fOutputValue.GetNcols());
0541    Architecture_t::MultiplyTranspose(tmpState, fState, fWeightsOutputGateState);
0542    Architecture_t::MultiplyTranspose(fOutputValue, input, fWeightsOutputGate);
0543    Architecture_t::ScaleAdd(fOutputValue, tmpState);
0544    Architecture_t::AddRowWise(fOutputValue, fOutputGateBias);
0545    DNN::evaluateDerivativeMatrix<Architecture_t>(dout, fOut, fOutputValue);
0546    DNN::evaluateMatrix<Architecture_t>(fOutputValue, fOut);
0547 }
0548 
0549 
0550 
0551  //______________________________________________________________________________
0552 template <typename Architecture_t>
0553 auto inline TBasicLSTMLayer<Architecture_t>::Forward(Tensor_t &input, bool  isTraining )
0554 -> void
0555 {
0556 
0557    // for Cudnn
0558    if (Architecture_t::IsCudnn()) {
0559 
0560       // input size is stride[1] of input tensor that is B x T x inputSize
0561       assert(input.GetStrides()[1] == this->GetInputSize());
0562 
0563       Tensor_t &x = this->fX;
0564       Tensor_t &y = this->fY;
0565       Architecture_t::Rearrange(x, input);
0566 
0567       //const auto &weights = this->GetWeightsAt(0);
0568       const auto &weights = this->GetWeightsTensor();
0569       // Tensor_t cx({1}); // not used for normal RNN
0570       // Tensor_t cy({1}); // not used for normal RNN
0571 
0572       // hx is fState - tensor are of right shape
0573       auto &hx = this->fState;
0574       //auto &cx = this->fCell;
0575       auto &cx = this->fCell; // pass an empty cell state
0576       // use same for hy and cy
0577       auto &hy = this->fState;
0578       auto &cy = this->fCell;
0579 
0580       auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
0581       auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
0582 
0583       Architecture_t::RNNForward(x, hx, cx, weights, y, hy, cy, rnnDesc, rnnWork, isTraining);
0584 
0585       if (fReturnSequence) {
0586          Architecture_t::Rearrange(this->GetOutput(), y); // swap B and T from y to Output
0587       } else {
0588          // tmp is a reference to y (full cudnn output)
0589          Tensor_t tmp = (y.At(y.GetShape()[0] - 1)).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
0590          Architecture_t::Copy(this->GetOutput(), tmp);
0591       }
0592 
0593       return;
0594    }
0595 
0596    // Standard CPU implementation
0597 
0598    // D : input size
0599    // H : state size
0600    // T : time size
0601    // B : batch size
0602 
0603    Tensor_t arrInput( fTimeSteps, this->GetBatchSize(), this->GetInputWidth());
0604    //Tensor_t &arrInput = this->GetX();
0605 
0606    Architecture_t::Rearrange(arrInput, input); // B x T x D
0607 
0608    Tensor_t arrOutput ( fTimeSteps, this->GetBatchSize(), fStateSize);
0609 
0610 
0611    if (!this->fRememberState) {
0612       InitState(DNN::EInitialization::kZero);
0613    }
0614 
0615    /*! Pass each gate values to CellForward() to calculate
0616     *  next hidden state and next cell state. */
0617    for (size_t t = 0; t < fTimeSteps; ++t) {
0618       /* Feed forward network: value of each gate being computed at each timestep t. */
0619       Matrix_t arrInputMt = arrInput[t];
0620       InputGate(arrInputMt, fDerivativesInput[t]);
0621       ForgetGate(arrInputMt, fDerivativesForget[t]);
0622       CandidateValue(arrInputMt, fDerivativesCandidate[t]);
0623       OutputGate(arrInputMt, fDerivativesOutput[t]);
0624 
0625       Architecture_t::Copy(this->GetInputGateTensorAt(t), fInputValue);
0626       Architecture_t::Copy(this->GetForgetGateTensorAt(t), fForgetValue);
0627       Architecture_t::Copy(this->GetCandidateGateTensorAt(t), fCandidateValue);
0628       Architecture_t::Copy(this->GetOutputGateTensorAt(t), fOutputValue);
0629 
0630       CellForward(fInputValue, fForgetValue, fCandidateValue, fOutputValue);
0631       Matrix_t arrOutputMt = arrOutput[t];
0632       Architecture_t::Copy(arrOutputMt, fState);
0633       Architecture_t::Copy(this->GetCellTensorAt(t), fCell);
0634    }
0635 
0636    // check if full output needs to be returned
0637    if (fReturnSequence)
0638       Architecture_t::Rearrange(this->GetOutput(), arrOutput); // B x T x D
0639    else {
0640       // get T[end[]]
0641       Tensor_t tmp = arrOutput.At(fTimeSteps - 1); // take last time step
0642       // shape of tmp is  for CPU (columnwise) B x D ,   need to reshape to  make a B x D x 1
0643       //  and transpose it to 1 x D x B  (this is how output is expected in columnmajor format)
0644       tmp = tmp.Reshape( {tmp.GetShape()[0], tmp.GetShape()[1], 1});
0645       assert(tmp.GetSize() == this->GetOutput().GetSize());
0646       assert( tmp.GetShape()[0] == this->GetOutput().GetShape()[2]);  // B is last dim in output and first in tmp
0647       Architecture_t::Rearrange(this->GetOutput(), tmp);
0648       // keep array output
0649       fY = arrOutput;
0650    }
0651 }
0652 
0653  //______________________________________________________________________________
0654 template <typename Architecture_t>
0655 auto inline TBasicLSTMLayer<Architecture_t>::CellForward(Matrix_t &inputGateValues, const Matrix_t &forgetGateValues,
0656                                                          const Matrix_t &candidateValues, const Matrix_t &outputGateValues)
0657 -> void
0658 {
0659 
0660    // Update cell state.
0661    Architecture_t::Hadamard(fCell, forgetGateValues);
0662    Architecture_t::Hadamard(inputGateValues, candidateValues);
0663    Architecture_t::ScaleAdd(fCell, inputGateValues);
0664 
0665    Matrix_t cache(fCell.GetNrows(), fCell.GetNcols());
0666    Architecture_t::Copy(cache, fCell);
0667 
0668    // Update hidden state.
0669    const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();
0670    DNN::evaluateMatrix<Architecture_t>(cache, fAT);
0671 
0672    /*! The Hadamard product of output_gate_value . tanh(cell_state)
0673     *  will be copied to next hidden state (passed to next LSTM cell)
0674     *  and we will update our outputGateValues also. */
0675    Architecture_t::Copy(fState, cache);
0676    Architecture_t::Hadamard(fState, outputGateValues);
0677 }
0678 
0679  //____________________________________________________________________________
0680 template <typename Architecture_t>
0681 auto inline TBasicLSTMLayer<Architecture_t>::Backward(Tensor_t &gradients_backward,           // B x T x D
0682                                                       const Tensor_t &activations_backward)   // B x T x D
0683 -> void
0684 {
0685 
0686    // BACKWARD for CUDNN
0687    if (Architecture_t::IsCudnn()) {
0688 
0689       Tensor_t &x = this->fX;
0690       Tensor_t &y = this->fY;
0691       Tensor_t &dx = this->fDx;
0692       Tensor_t &dy = this->fDy;
0693 
0694       // input size is stride[1] of input tensor that is B x T x inputSize
0695       assert(activations_backward.GetStrides()[1] == this->GetInputSize());
0696 
0697       Architecture_t::Rearrange(x, activations_backward);
0698 
0699       if (!fReturnSequence) {
0700 
0701          // Architecture_t::InitializeZero(dy);
0702          Architecture_t::InitializeZero(dy);
0703 
0704          // Tensor_t tmp1 = y.At(y.GetShape()[0] - 1).Reshape({y.GetShape()[1], 1, y.GetShape()[2]});
0705          // dy is a tensor of shape (rowmajor for Cudnn): T x B x S
0706          // and this->ActivationGradients is  B x (T=1) x S
0707          Tensor_t tmp2 = dy.At(dy.GetShape()[0] - 1).Reshape({dy.GetShape()[1], 1, dy.GetShape()[2]});
0708 
0709          // Architecture_t::Copy(tmp1, this->GetOutput());
0710          Architecture_t::Copy(tmp2, this->GetActivationGradients());
0711       } else {
0712          Architecture_t::Rearrange(y, this->GetOutput());
0713          Architecture_t::Rearrange(dy, this->GetActivationGradients());
0714       }
0715 
0716       // Architecture_t::PrintTensor(this->GetOutput(), "output before bwd");
0717 
0718       // for cudnn Matrix_t and Tensor_t are same type
0719       const auto &weights = this->GetWeightsTensor();
0720       auto &weightGradients = this->GetWeightGradientsTensor();
0721       // note that cudnnRNNBackwardWeights accumulate the weight gradients.
0722       // We need then to initialize the tensor to zero every time
0723       Architecture_t::InitializeZero(weightGradients);
0724 
0725       // hx is fState
0726       auto &hx = this->GetState();
0727       auto &cx = this->GetCell();
0728       //auto &cx = this->GetCell();
0729       // use same for hy and cy
0730       auto &dhy = hx;
0731       auto &dcy = cx;
0732       auto &dhx = hx;
0733       auto &dcx = cx;
0734 
0735       auto & rnnDesc = static_cast<RNNDescriptors_t &>(*fDescriptors);
0736       auto & rnnWork = static_cast<RNNWorkspace_t &>(*fWorkspace);
0737 
0738       Architecture_t::RNNBackward(x, hx, cx, y, dy, dhy, dcy, weights, dx, dhx, dcx, weightGradients, rnnDesc, rnnWork);
0739 
0740       // Architecture_t::PrintTensor(this->GetOutput(), "output after bwd");
0741 
0742       if (gradients_backward.GetSize() != 0)
0743          Architecture_t::Rearrange(gradients_backward, dx);
0744 
0745       return;
0746    }
0747    // CPU implementation
0748 
0749    // gradients_backward is activationGradients of layer before it, which is input layer.
0750    // Currently, gradients_backward is for input(x) and not for state.
0751    // For the state it can be:
0752    Matrix_t state_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
0753    DNN::initialize<Architecture_t>(state_gradients_backward, DNN::EInitialization::kZero); // B x H
0754 
0755 
0756    Matrix_t cell_gradients_backward(this->GetBatchSize(), fStateSize); // B x H
0757    DNN::initialize<Architecture_t>(cell_gradients_backward, DNN::EInitialization::kZero); // B x H
0758 
0759    // if dummy is false gradients_backward will be written back on the matrix
0760    bool dummy = false;
0761    if (gradients_backward.GetSize() == 0 || gradients_backward[0].GetNrows() == 0 || gradients_backward[0].GetNcols() == 0) {
0762       dummy = true;
0763    }
0764 
0765 
0766    Tensor_t arr_gradients_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
0767 
0768 
0769    //Architecture_t::Rearrange(arr_gradients_backward, gradients_backward); // B x T x D
0770    // activations_backward is input.
0771    Tensor_t arr_activations_backward ( fTimeSteps, this->GetBatchSize(), this->GetInputSize());
0772 
0773    Architecture_t::Rearrange(arr_activations_backward, activations_backward); // B x T x D
0774 
0775    /*! For backpropagation, we need to calculate loss. For loss, output must be known.
0776     *  We obtain outputs during forward propagation and place the results in arr_output tensor. */
0777    Tensor_t arr_output (  fTimeSteps, this->GetBatchSize(), fStateSize);
0778 
0779    Matrix_t initState(this->GetBatchSize(), fCellSize); // B x H
0780    DNN::initialize<Architecture_t>(initState, DNN::EInitialization::kZero); // B x H
0781 
0782    // This will take partial derivative of state[t] w.r.t state[t-1]
0783 
0784    Tensor_t arr_actgradients(fTimeSteps, this->GetBatchSize(), fStateSize);
0785 
0786    if (fReturnSequence) {
0787       Architecture_t::Rearrange(arr_output, this->GetOutput());
0788       Architecture_t::Rearrange(arr_actgradients, this->GetActivationGradients());
0789    } else {
0790       // here for CPU need to transpose the input activation gradients into the right format
0791       arr_output = fY;
0792       Architecture_t::InitializeZero(arr_actgradients);
0793       // need to reshape to pad a time dimension = 1 (note here is columnmajor tensors)
0794       Tensor_t tmp_grad = arr_actgradients.At(fTimeSteps - 1).Reshape( {this->GetBatchSize(), fStateSize, 1});
0795       assert(tmp_grad.GetSize() == this->GetActivationGradients().GetSize());
0796       assert(tmp_grad.GetShape()[0] == this->GetActivationGradients().GetShape()[2]);  // B in tmp is [0] and [2] in input act. gradients
0797 
0798       Architecture_t::Rearrange(tmp_grad, this->GetActivationGradients());
0799    }
0800 
0801    /*! There are total 8 different weight matrices and 4 bias vectors.
0802     *  Re-initialize them with zero because it should have some value. (can't be garbage values) */
0803 
0804    // Input Gate.
0805    fWeightsInputGradients.Zero();
0806    fWeightsInputStateGradients.Zero();
0807    fInputBiasGradients.Zero();
0808 
0809    // Forget Gate.
0810    fWeightsForgetGradients.Zero();
0811    fWeightsForgetStateGradients.Zero();
0812    fForgetBiasGradients.Zero();
0813 
0814    // Candidate Gate.
0815    fWeightsCandidateGradients.Zero();
0816    fWeightsCandidateStateGradients.Zero();
0817    fCandidateBiasGradients.Zero();
0818 
0819    // Output Gate.
0820    fWeightsOutputGradients.Zero();
0821    fWeightsOutputStateGradients.Zero();
0822    fOutputBiasGradients.Zero();
0823 
0824 
0825    for (size_t t = fTimeSteps; t > 0; t--) {
0826       // Store the sum of gradients obtained at each timestep during backward pass.
0827       Architecture_t::ScaleAdd(state_gradients_backward, arr_actgradients[t-1]);
0828       if (t > 1) {
0829          const Matrix_t &prevStateActivations = arr_output[t-2];
0830          const Matrix_t &prevCellActivations = this->GetCellTensorAt(t-2);
0831          // During forward propagation, each gate value calculates their gradients.
0832          Matrix_t dx = arr_gradients_backward[t-1];
0833          CellBackward(state_gradients_backward, cell_gradients_backward,
0834                       prevStateActivations, prevCellActivations,
0835                       this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
0836                       this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
0837                       arr_activations_backward[t-1], dx,
0838                       fDerivativesInput[t-1], fDerivativesForget[t-1],
0839                       fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
0840       } else {
0841          const Matrix_t &prevStateActivations = initState;
0842          const Matrix_t &prevCellActivations = initState;
0843          Matrix_t dx = arr_gradients_backward[t-1];
0844          CellBackward(state_gradients_backward, cell_gradients_backward,
0845                       prevStateActivations, prevCellActivations,
0846                       this->GetInputGateTensorAt(t-1), this->GetForgetGateTensorAt(t-1),
0847                       this->GetCandidateGateTensorAt(t-1), this->GetOutputGateTensorAt(t-1),
0848                       arr_activations_backward[t-1], dx,
0849                       fDerivativesInput[t-1], fDerivativesForget[t-1],
0850                       fDerivativesCandidate[t-1], fDerivativesOutput[t-1], t-1);
0851         }
0852    }
0853 
0854    if (!dummy) {
0855       Architecture_t::Rearrange(gradients_backward, arr_gradients_backward );
0856    }
0857 
0858 }
0859 
0860 
0861  //______________________________________________________________________________
0862 template <typename Architecture_t>
0863 auto inline TBasicLSTMLayer<Architecture_t>::CellBackward(Matrix_t & state_gradients_backward,
0864                                                           Matrix_t & cell_gradients_backward,
0865                                                           const Matrix_t & precStateActivations, const Matrix_t & precCellActivations,
0866                                                           const Matrix_t & input_gate, const Matrix_t & forget_gate,
0867                                                           const Matrix_t & candidate_gate, const Matrix_t & output_gate,
0868                                                           const Matrix_t & input, Matrix_t & input_gradient,
0869                                                           Matrix_t &di, Matrix_t &df, Matrix_t &dc, Matrix_t &dout,
0870                                                           size_t t)
0871 -> Matrix_t &
0872 {
0873    /*! Call here LSTMLayerBackward() to pass parameters i.e. gradient
0874     *  values obtained from each gate during forward propagation. */
0875 
0876 
0877    // cell gradient for current time step
0878    const DNN::EActivationFunction fAT = this->GetActivationFunctionF2();
0879    Matrix_t cell_gradient(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
0880    DNN::evaluateDerivativeMatrix<Architecture_t>(cell_gradient, fAT, this->GetCellTensorAt(t));
0881 
0882    // cell tanh value for current time step
0883    Matrix_t cell_tanh(this->GetCellTensorAt(t).GetNrows(), this->GetCellTensorAt(t).GetNcols());
0884    Architecture_t::Copy(cell_tanh, this->GetCellTensorAt(t));
0885    DNN::evaluateMatrix<Architecture_t>(cell_tanh, fAT);
0886 
0887    return Architecture_t::LSTMLayerBackward(state_gradients_backward, cell_gradients_backward,
0888                                             fWeightsInputGradients, fWeightsForgetGradients, fWeightsCandidateGradients,
0889                                             fWeightsOutputGradients, fWeightsInputStateGradients, fWeightsForgetStateGradients,
0890                                             fWeightsCandidateStateGradients, fWeightsOutputStateGradients, fInputBiasGradients, fForgetBiasGradients,
0891                                             fCandidateBiasGradients, fOutputBiasGradients, di, df, dc, dout,
0892                                             precStateActivations, precCellActivations,
0893                                             input_gate, forget_gate, candidate_gate, output_gate,
0894                                             fWeightsInputGate, fWeightsForgetGate, fWeightsCandidate, fWeightsOutputGate,
0895                                             fWeightsInputGateState, fWeightsForgetGateState, fWeightsCandidateState,
0896                                             fWeightsOutputGateState, input, input_gradient,
0897                                             cell_gradient, cell_tanh);
0898 }
0899 
0900  //______________________________________________________________________________
0901 template <typename Architecture_t>
0902 auto TBasicLSTMLayer<Architecture_t>::InitState(DNN::EInitialization /* m */)
0903 -> void
0904 {
0905    DNN::initialize<Architecture_t>(this->GetState(),  DNN::EInitialization::kZero);
0906    DNN::initialize<Architecture_t>(this->GetCell(),  DNN::EInitialization::kZero);
0907 }
0908 
0909  //______________________________________________________________________________
0910 template<typename Architecture_t>
0911 auto TBasicLSTMLayer<Architecture_t>::Print() const
0912 -> void
0913 {
0914    std::cout << " LSTM Layer: \t ";
0915    std::cout << " (NInput = " << this->GetInputSize();  // input size
0916    std::cout << ", NState = " << this->GetStateSize();  // hidden state size
0917    std::cout << ", NTime  = " << this->GetTimeSteps() << " )";  // time size
0918    std::cout << "\tOutput = ( " << this->GetOutput().GetFirstSize() << " , " << this->GetOutput()[0].GetNrows() << " , " << this->GetOutput()[0].GetNcols() << " )\n";
0919 }
0920 
0921  //______________________________________________________________________________
0922 template <typename Architecture_t>
0923 auto inline TBasicLSTMLayer<Architecture_t>::AddWeightsXMLTo(void *parent)
0924 -> void
0925 {
0926    auto layerxml = gTools().xmlengine().NewChild(parent, nullptr, "LSTMLayer");
0927 
0928    // Write all other info like outputSize, cellSize, inputSize, timeSteps, rememberState
0929    gTools().xmlengine().NewAttr(layerxml, nullptr, "StateSize", gTools().StringFromInt(this->GetStateSize()));
0930    gTools().xmlengine().NewAttr(layerxml, nullptr, "CellSize", gTools().StringFromInt(this->GetCellSize()));
0931    gTools().xmlengine().NewAttr(layerxml, nullptr, "InputSize", gTools().StringFromInt(this->GetInputSize()));
0932    gTools().xmlengine().NewAttr(layerxml, nullptr, "TimeSteps", gTools().StringFromInt(this->GetTimeSteps()));
0933    gTools().xmlengine().NewAttr(layerxml, nullptr, "RememberState", gTools().StringFromInt(this->DoesRememberState()));
0934    gTools().xmlengine().NewAttr(layerxml, nullptr, "ReturnSequence", gTools().StringFromInt(this->DoesReturnSequence()));
0935 
0936    // write weights and bias matrices
0937    this->WriteMatrixToXML(layerxml, "InputWeights", this->GetWeightsAt(0));
0938    this->WriteMatrixToXML(layerxml, "InputStateWeights", this->GetWeightsAt(1));
0939    this->WriteMatrixToXML(layerxml, "InputBiases", this->GetBiasesAt(0));
0940    this->WriteMatrixToXML(layerxml, "ForgetWeights", this->GetWeightsAt(2));
0941    this->WriteMatrixToXML(layerxml, "ForgetStateWeights", this->GetWeightsAt(3));
0942    this->WriteMatrixToXML(layerxml, "ForgetBiases", this->GetBiasesAt(1));
0943    this->WriteMatrixToXML(layerxml, "CandidateWeights", this->GetWeightsAt(4));
0944    this->WriteMatrixToXML(layerxml, "CandidateStateWeights", this->GetWeightsAt(5));
0945    this->WriteMatrixToXML(layerxml, "CandidateBiases", this->GetBiasesAt(2));
0946    this->WriteMatrixToXML(layerxml, "OuputWeights", this->GetWeightsAt(6));
0947    this->WriteMatrixToXML(layerxml, "OutputStateWeights", this->GetWeightsAt(7));
0948    this->WriteMatrixToXML(layerxml, "OutputBiases", this->GetBiasesAt(3));
0949 }
0950 
0951  //______________________________________________________________________________
0952 template <typename Architecture_t>
0953 auto inline TBasicLSTMLayer<Architecture_t>::ReadWeightsFromXML(void *parent)
0954 -> void
0955 {
0956     // Read weights and biases
0957    this->ReadMatrixXML(parent, "InputWeights", this->GetWeightsAt(0));
0958    this->ReadMatrixXML(parent, "InputStateWeights", this->GetWeightsAt(1));
0959    this->ReadMatrixXML(parent, "InputBiases", this->GetBiasesAt(0));
0960    this->ReadMatrixXML(parent, "ForgetWeights", this->GetWeightsAt(2));
0961    this->ReadMatrixXML(parent, "ForgetStateWeights", this->GetWeightsAt(3));
0962    this->ReadMatrixXML(parent, "ForgetBiases", this->GetBiasesAt(1));
0963    this->ReadMatrixXML(parent, "CandidateWeights", this->GetWeightsAt(4));
0964    this->ReadMatrixXML(parent, "CandidateStateWeights", this->GetWeightsAt(5));
0965    this->ReadMatrixXML(parent, "CandidateBiases", this->GetBiasesAt(2));
0966    this->ReadMatrixXML(parent, "OuputWeights", this->GetWeightsAt(6));
0967    this->ReadMatrixXML(parent, "OutputStateWeights", this->GetWeightsAt(7));
0968    this->ReadMatrixXML(parent, "OutputBiases", this->GetBiasesAt(3));
0969 }
0970 
0971 } // namespace LSTM
0972 } // namespace DNN
0973 } // namespace TMVA
0974 
0975 #endif // LSTM_LAYER_H