#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
#define TMVA_SOFIE_ROPERATOR_GRU_I

namespace TMVA {
namespace Experimental {
namespace SOFIE {

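// Infer the output types of the GRU operator: both outputs (the full
// sequence of hidden states Y and the last hidden state Y_h) have the same
// element type as the input tensor X.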
template <typename T>
auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
-> std::vector<ETensorType> {
    ETensorType out = input[0];
    return {out, out};
}

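// Infer the output shapes. Following the ONNX GRU specification, input[0]
// is X and input[1] is the weight tensor W of shape
// [num_directions, 3 * hidden_size, input_size], so hidden_size is
// recovered as W.shape[1] / 3. With layout == 0 (time-major), Y has shape
// [seq_length, num_directions, batch_size, hidden_size] and Y_h has shape
// [num_directions, batch_size, hidden_size]; with layout == 1 (batch-major)
// the batch dimension comes first.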
template<typename T>
auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
-> std::vector<std::vector<size_t>> {
    size_t num_directions = input[1][0];
    size_t hidden_size = input[1][1] / 3;
    if (fAttrLayout == 0) {
        size_t seq_length = input[0][0];
        size_t batch_size = input[0][1];
        std::vector<std::vector<size_t>> ret(
            {{seq_length, num_directions, batch_size, hidden_size},
            {num_directions, batch_size, hidden_size}});
        return ret;
    } else {
        size_t batch_size = input[0][0];
        size_t seq_length = input[0][1];
        std::vector<std::vector<size_t>> ret(
            {{batch_size, seq_length, num_directions, hidden_size},
            {batch_size, num_directions, hidden_size}});
        return ret;
    }
}

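// Initialize the operator: validate the shapes of the input tensors
// (X, W, R and the optional B, sequence_lens and initial_h), broadcast the
// bias if needed, register the output tensors and check the attributes.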
template<typename T>
void ROperator_GRU<T>::Initialize(RModel& model){

   fUseSession = model.UseSession();
   // Check the input and output tensors
   if (!model.CheckIfTensorAlreadyExist(fNX)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
   }
   fShapeX = model.GetTensorShape(fNX);
   if (fShapeX.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNW)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
   }
   fShapeW = model.GetTensorShape(fNW);
   if (fShapeW.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNR)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
   }
   fShapeR = model.GetTensorShape(fNR);
   if (fShapeR.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
   }
   if (!fNB.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not found in model.");
      }
      fShapeB = model.GetTensorShape(fNB);
      if (fShapeB.size() != 2 && fShapeB.size() != 4) {
         throw std::runtime_error("TMVA SOFIE GRU op input tensor " + fNB + " is not of 2 or 4 dimensions.");
      }
      if (fShapeB.size() == 2) {
         // Broadcasting the bias
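         // The ONNX bias B has shape [num_directions, 6 * hidden_size],
         // concatenating [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h]. It is
         // expanded here to [num_directions, 6, seq_length, batch_size,
         // hidden_size] by tiling each of the six chunks over the sequence
         // and batch dimensions, so that each chunk can later be added to a
         // gate buffer with a single BLAS saxpy call.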
         auto original_data = model.GetInitializedTensorData(fNB);
         size_t num_directions = fShapeW[0];
         size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
         size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
         if (fType == "float") {
            float *original_bias = static_cast<float*>(original_data.get());
            float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
            for (size_t direction = 0; direction < num_directions; direction++) {
               for (size_t i = 0; i < 6; i++) {
                  for (size_t seq = 0; seq < seq_length; seq++) {
                     for (size_t batch = 0; batch < batch_size; batch++) {
                        size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
                        size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
                                        i * batch_size * seq_length * fAttrHiddenSize +
                                        seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
                        std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
                                  new_bias + offset);
                     }
                  }
               }
            }

            std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
            std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
            model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
            fShapeB = model.GetTensorShape(fNB);
         }
      }
   }
   if (!fNSequence_lens.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not found in model.");
      }
      fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
      if (fShapeSequence_lens.size() != 1) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not of 1 dimension.");
      }
   }
   if (!fNInitial_h.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not found in model.");
      }
      fShapeInitial_h = model.GetTensorShape(fNInitial_h);
      if (fShapeInitial_h.size() != 3) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not of 3 dimensions.");
      }
   }
   if (!fNY.empty()) {
      fShapeY = ShapeInference({fShapeX, fShapeW})[0];
      if (!model.CheckIfTensorAlreadyExist(fNY)) {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
      }
   }
   if (!fNY_h.empty()) {
      fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
      if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
         model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
      }
   }
   // Check the attributes
   for (auto &activation : fAttrActivations) {
      if (activation != "Relu" && activation != "Tanh" &&
          activation != "Sigmoid" && activation != "Affine" &&
          activation != "LeakyRelu" && activation != "ThresholdRelu" &&
          activation != "ScaledTanh" && activation != "HardSigmoid" &&
          activation != "Elu" && activation != "Softsign" &&
          activation != "Softplus") {
         throw std::runtime_error("TMVA SOFIE - Activation function " +
                                  activation + " not implemented");
      }
   }
   // "reverse" is treated as a synonym for "backward"
   if (fAttrDirection == "reverse") fAttrDirection = "backward";
   if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
       fAttrDirection != "bidirectional") {
      throw std::runtime_error(
          "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
          fAttrDirection);
   }
   if (3 * fAttrHiddenSize != fShapeW[1]) {
      throw std::runtime_error(
          "TMVA SOFIE - fAttrHiddenSize must be equal to " +
          std::to_string(fShapeW[1] / 3));
   }
   if (fAttrLayout > 1) {
      throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
                               std::to_string(fAttrLayout) +
                               " must be 0 (timewise) or 1 (batchwise)");
   }
   if (fAttrLinearBeforeReset > 1) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrLinearBeforeReset = " + std::to_string(fAttrLinearBeforeReset)
         + " must be 0 or 1.");
   }
   if (fAttrActivations.empty()) {
      if (fAttrDirection == "bidirectional") {
         fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
      } else {
         fAttrActivations = {"Sigmoid", "Tanh"};
      }
   }
}

// generate code for Session data members (e.g. internal vectors)
template <typename T>
std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
{
   opName = "op_" + opName;
   std::stringstream out;

   size_t num_directions = fShapeW[0];
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];

   if (fAttrLayout != 0) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
          << seq_length * batch_size * input_size << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
      // Note: GRU has no cell state; this vector appears to be carried over
      // from the LSTM operator and is not referenced by the generated code.
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_cell_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
   }
   // Set the feedforward
   size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_update_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_reset_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_hidden_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   // gate results
   size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";

   // feedback
   out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
       << batch_size * fAttrHiddenSize << ");\n";

   // hidden state
   if (fAttrLayout != 0 || fNY.empty()) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
   }

   out << "\n";

   return out.str();
}


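// Generate the inference code for the GRU operator. For each direction and
// each timestep t, the generated code evaluates the ONNX GRU equations:
//   z_t = f(X_t * W_z^T + H_{t-1} * R_z^T + Wb_z + Rb_z)   (update gate)
//   r_t = f(X_t * W_r^T + H_{t-1} * R_r^T + Wb_r + Rb_r)   (reset gate)
//   if linear_before_reset == 0:
//      h_t = g(X_t * W_h^T + (r_t o H_{t-1}) * R_h^T + Wb_h + Rb_h)
//   else:
//      h_t = g(X_t * W_h^T + r_t o (H_{t-1} * R_h^T + Rb_h) + Wb_h)
//   H_t = (1 - z_t) o h_t + z_t o H_{t-1}
// where o denotes the element-wise product and f, g are the chosen
// activation functions (Sigmoid and Tanh by default). The matrix products
// are emitted as BLAS sgemm/saxpy calls.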
template<typename T>
auto ROperator_GRU<T>::Generate(std::string OpName)
-> std::string {
   OpName = "op_" + OpName;
   std::stringstream out;

   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];
   size_t num_directions = fShapeW[0];

   // set the input
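   // With layout == 1 the input arrives batch-major as
   // [batch_size, seq_length, input_size] and is transposed here into the
   // time-major layout [seq_length, batch_size, input_size] assumed by the
   // rest of the generated code.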
   if (fAttrLayout == 0) {
      out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
      }
      out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
          << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
          << seq_length * input_size << " + seq * " << input_size << " + i];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

   // Set the initial hidden state
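   // As with the input, a batch-major (layout == 1) initial_h tensor is
   // transposed into the [num_directions, batch_size, hidden_size] layout
   // before use.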
   if (!fNInitial_h.empty()) {
      if (fAttrLayout == 0) {
         out << SP << fType << " *" << OpName << "_initial_hidden_state = tensor_"
             << fNInitial_h << ";\n";
      } else {
         if (fUseSession) {
            out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
                << "_initial_hidden_state.data();\n";
         } else {
            out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
                fAttrHiddenSize << "];\n";
         }
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
            out << SP << SP << SP << OpName << "_initial_hidden_state["
                << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
                << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
                << " + " << direction * fAttrHiddenSize << " + h];\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
   }

   // Set the feedforward
   size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_f_update_gate = fVec_" << OpName << "_f_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_reset_gate = fVec_" << OpName << "_f_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_hidden_gate = fVec_" << OpName << "_f_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
   }
   // Set the gates
   size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_update_gate = fVec_" << OpName << "_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_reset_gate = fVec_" << OpName << "_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_hidden_gate = fVec_" << OpName << "_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
   }
   // Set the hidden state
   if (fAttrLayout == 0 && !fNY.empty()) {
      out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
      }
   }

   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_feedback = fVec_" << OpName << "_feedback.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
   }

   out << SP << "char " << OpName << "_transA = 'N';\n";
   out << SP << "char " << OpName << "_transB = 'T';\n";
   out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
   out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
   out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
   out << SP << "int " << OpName << "_k = " << input_size << ";\n";
   if (fType == "float") {
      out << SP << "float " << OpName << "_alpha = 1.;\n";
      out << SP << "float " << OpName << "_beta = 0.;\n";
   }
   if (!fNB.empty()) {
      out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
   }
   out << SP << "int " << OpName << "_incx = 1;\n";
   out << SP << "int " << OpName << "_incy = 1;\n";
   out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";

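   // Feedforward contributions, computed for the whole sequence with one
   // GEMM per gate and direction: W is stored as
   // [num_directions, 3 * hidden_size, input_size] with the gates ordered
   // z (update), r (reset), h (hidden), which fixes the row offsets used
   // below. The transB/transA arguments are swapped because the Fortran
   // BLAS routines expect column-major matrices.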
   for (size_t direction = 0; direction < num_directions; direction++) {
      if (direction == 0) {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
                << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      } else {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
            size_t wz_offset = 3 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      }

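      // Bias contributions. After the broadcast done in Initialize, tensor B
      // is laid out as [num_directions, 6, seq_length, batch_size,
      // hidden_size] with chunks ordered Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h,
      // so chunk k of direction d starts at offset
      // (6 * d + k) * seq_length * batch_size * hidden_size. When
      // fAttrLinearBeforeReset != 0, Rb_h is instead added later, inside the
      // feedback computation.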
      if (!fNB.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         } else {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         }
      }

      // Copy the feedforward into the gates
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t f_seq_size = batch_size * fAttrHiddenSize;
      out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
          << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
          << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
          << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
      out << SP << "}\n";

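      // Recurrent part: the hidden state at timestep t depends on t - 1, so
      // the generated code loops over the timesteps one at a time; for the
      // backward direction the sequence is traversed in reverse through
      // index = seq_length - 1 - seq.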
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      if (fAttrDirection == "backward" || direction == 1) {
         out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
      } else {
         out << SP << SP << "size_t index = seq;\n";
      }
      out << SP << SP << "int m2 = " << batch_size << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t size = batch_size * fAttrHiddenSize;
      // gate = gate + initial_hidden_state * Recurrence^T
      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                   << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         } else { // direction=1
            if (fType == "float") {
               size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         }
      }
      out << SP << SP << "} else {\n";
      // gate = gate + previous_hidden_state * Recurrence^T
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
         if (fType == "float") {
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
            size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         if (fType == "float") {
            size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
                << OpName << "_n);\n";
            size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      }
      out << SP << SP << "}\n";

      // Clip the elements of the update gate and the reset gate into the range [-fClip, fClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
             << ") ? z : " << fAttrClip << ";\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
             << ") ? r : " << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }

      // Apply the activation function to the update gate and the reset gate
      if (fAttrActivations[direction * 2] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_update_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
         if (fType == "float") {
            out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_reset_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
         }
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
         }
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
             << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
             << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2] = Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      }

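      // Feedback term of the hidden gate. With fAttrLinearBeforeReset == 0
      // the reset gate is applied to the previous hidden state before the
      // multiplication by R_h^T; with fAttrLinearBeforeReset == 1 the linear
      // recurrence (including its bias Rb_h) is computed first and the reset
      // gate is applied to the result.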
      if (fAttrLinearBeforeReset == 0) {
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = reset_gate o initial_hidden_state
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
                << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
            out << SP << SP << SP << "}\n";
         }
         out << SP << SP << "} else {\n";
         // feedback = reset_gate o previous_hidden_state
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
             << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << "}\n";
         // feedback = feedback * R_h^T
         size_t rh_offset = (direction == 0)
                               ? 2 * fAttrHiddenSize * fAttrHiddenSize
                               : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
             << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
             << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
      } else { // fAttrLinearBeforeReset=1
         // feedback = previous_hidden_state * R_h^T
         size_t rh_offset = (direction == 0)
                               ? 2 * fAttrHiddenSize * fAttrHiddenSize
                               : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = initial_hidden_state * R_h^T
            out << SP << SP << SP
               << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
               << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
               << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
               << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         }
         out << SP << SP << "} else {\n";
         // case for seq > 0
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                  << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                  << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
             << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
             << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         // endif on seq == 0 or not
         out << SP << SP << "}\n";
         // Add the bias of the recurrence to feedback
         if (!fNB.empty()) {
            size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
                                                 : 11 * batch_size * seq_length * fAttrHiddenSize;
            out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
                << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
                << OpName << "_feedback, &" << OpName << "_incy);\n";
         }
         // feedback = reset_gate o feedback
         out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
         out << SP << SP << "}\n";
      }

      // hidden_gate = hidden_gate + feedback
      out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
          << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
          << OpName << "_incy);\n";

      // Clip the elements of the hidden gate into the range [-fClip, fClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
         }
         out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
             << fAttrClip << ";\n";
         out << SP << SP << "}\n";
      }

      // Apply the activation function to the hidden gate
      if (fAttrActivations[direction * 2 + 1] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
             << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
             << fAttrActivationBeta[direction * 2 + 1] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
                << " * " << OpName << "_hidden_gate[i]);\n";
         }
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
                << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
            out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
         }
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
             << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
             << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2 + 1] = Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
             << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      }

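      // Combine the gates into the new hidden state,
      // H_t = (1 - z_t) o h_t + z_t o H_{t-1}, using the initial hidden
      // state as H_{t-1} at the first timestep.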
0865       // hidden_state = (1 - update_gate) o hidden_gate
0866       out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0867       out << SP << SP << SP << OpName << "_hidden_state[i] = ( 1. - " << OpName
0868           << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
0869       out << SP << SP << "}\n";
0870 
0871       out << SP << SP << "if (seq == 0) {\n";
0872       if (!fNInitial_h.empty()) {
0873          // hidden_state += update_gate o initial_hidden_state
0874          out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0875          out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
0876              << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
0877          out << SP << SP << SP << "}\n";
0878       }
0879       out << SP << SP << "} else {\n";
0880       // hidden_state += update_gate o previous_hidden_state
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
          << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";

      out << SP << "}\n";
   }
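   // At this point the hidden-state buffer holds the full sequence of states
   // laid out as [seq_length, num_directions, batch_size, hidden_size]; the
   // offset arithmetic above and the copies below all rely on this layout.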

   // Padding the hidden state for GRU with different sequence lengths
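   // (ONNX semantics: for a batch entry whose sequence is shorter than the
   // longest one, all states at steps seq >= sequence_lens[batch] are zeroed.)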
   if (!fNSequence_lens.empty()) {
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
      for (size_t direction = 0; direction < num_directions; direction++) {
         out << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
         out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
             << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
             << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
         out << SP << SP << SP << SP << "}\n";
      }
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

   // Copy the hidden state into Y and Y_h
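   // With fAttrLayout == 0, Y shares the work buffer's
   // [seq_length, num_directions, batch_size, hidden_size] layout, so only
   // Y_h ([num_directions, batch_size, hidden_size]) needs an explicit copy.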
   if (fAttrLayout == 0) {
      if (!fNY_h.empty()) {
         // Copy hidden_state into Y_h
         if (fNSequence_lens.empty()) {
            size_t yh_size = batch_size * fAttrHiddenSize;
            if (fAttrDirection == "backward") {
               out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
                   << yh_size << ", tensor_" << fNY_h << ");\n";
            } else {
               size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
               out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
                   << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
            }
            if (num_directions == 2) {
               out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
                   << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
            }
         } else { // GRU with different sequence lengths
            if (fAttrDirection == "backward") {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
               out << SP << "}\n";
            } else {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t seq = tensor_" << fNSequence_lens << "[batch] - 1;\n";
               out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
            if (num_directions == 2) {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
         }
      }
   } else { // fAttrLayout == 1
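      // With fAttrLayout == 1 the outputs are batch-major: Y has shape
      // [batch_size, seq_length, num_directions, hidden_size] and Y_h has
      // [batch_size, num_directions, hidden_size], so both are filled with
      // explicit per-batch copies from the work buffer.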
      if (!fNY.empty()) {
         // Copy hidden_state into Y
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
            out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
                << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
      if (!fNY_h.empty()) {
         // Copy the hidden_state into Y_h
         if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         } else {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            if (fNSequence_lens.empty()) {
               out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
            } else {
               out << SP << SP << "size_t seq = tensor_" << fNSequence_lens << "[batch] - 1;\n";
            }
            out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         if (num_directions == 2) {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
      }
   }

   return out.str();
}

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif