#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
#define TMVA_SOFIE_ROPERATOR_GRU_I

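// Code generator for the ONNX GRU operator. For each time step t the generated
// code computes, in the ONNX gate order z, r, h:
//    z_t  = f(X_t*W_z^T + H_{t-1}*R_z^T + Wb_z + Rb_z)
//    r_t  = f(X_t*W_r^T + H_{t-1}*R_r^T + Wb_r + Rb_r)
//    h~_t = g(X_t*W_h^T + (r_t o H_{t-1})*R_h^T + Rb_h + Wb_h)    // linear_before_reset = 0
//    h~_t = g(X_t*W_h^T + r_t o (H_{t-1}*R_h^T + Rb_h) + Wb_h)    // linear_before_reset = 1
//    H_t  = (1 - z_t) o h~_t + z_t o H_{t-1}
// where f and g default to Sigmoid and Tanh and o is the element-wise product.
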
namespace TMVA {
namespace Experimental {
namespace SOFIE {

template <typename T>
auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
-> std::vector<ETensorType> {
    ETensorType out = input[0];
    return {out, out};
}

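// Infers the output shapes from the input X ([seq_length, batch_size, input_size]
// for layout 0, [batch_size, seq_length, input_size] for layout 1) and the weight
// tensor W ([num_directions, 3 * hidden_size, input_size]). The two outputs are
// the full hidden-state sequence Y and the final hidden state Y_h.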
template<typename T>
auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
-> std::vector<std::vector<size_t>> {
    size_t num_directions = input[1][0];
    size_t hidden_size = input[1][1] / 3;
    if (fAttrLayout == 0) {
        size_t seq_length = input[0][0];
        size_t batch_size = input[0][1];
        std::vector<std::vector<size_t>> ret(
            {{seq_length, num_directions, batch_size, hidden_size},
            {num_directions, batch_size, hidden_size}});
        return ret;
    } else {
        size_t batch_size = input[0][0];
        size_t seq_length = input[0][1];
        std::vector<std::vector<size_t>> ret(
            {{batch_size, seq_length, num_directions, hidden_size},
            {batch_size, num_directions, hidden_size}});
        return ret;
    }
}

template<typename T>
void ROperator_GRU<T>::Initialize(RModel &model) {
   fUseSession = model.UseSession();
   // Check the input and output tensors
   if (!model.CheckIfTensorAlreadyExist(fNX)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
   }
   fShapeX = model.GetTensorShape(fNX);
   if (fShapeX.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNW)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
   }
   fShapeW = model.GetTensorShape(fNW);
   if (fShapeW.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNR)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
   }
   fShapeR = model.GetTensorShape(fNR);
   if (fShapeR.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
   }
   if (!fNB.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not found in model.");
      }
      fShapeB = model.GetTensorShape(fNB);
      if (fShapeB.size() != 2 && fShapeB.size() != 4) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not of 2 or 4 dimensions.");
      }
      if (fShapeB.size() == 2) {
         // Broadcasting the bias
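         // ONNX packs the bias as [num_directions, 6 * hidden_size] in the order
         // Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h. Each hidden_size-long slice is
         // replicated over (seq_length, batch_size) so that later each gate's
         // bias can be added with a single saxpy call.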
         auto original_data = model.GetInitializedTensorData(fNB);
         size_t num_directions = fShapeW[0];
         size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
         size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
         if (fType == "float") {
            float *original_bias = static_cast<float*>(original_data.get());
            float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
            for (size_t direction = 0; direction < num_directions; direction++) {
               for (size_t i = 0; i < 6; i++) {
                  for (size_t seq = 0; seq < seq_length; seq++) {
                     for (size_t batch = 0; batch < batch_size; batch++) {
                        size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
                        size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
                                        i * batch_size * seq_length * fAttrHiddenSize +
                                        seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
                        std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
                                  new_bias + offset);
                     }
                  }
               }
            }

            std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
            std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
            model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
            fShapeB = model.GetTensorShape(fNB);
         }
      }
   }
   if (!fNSequence_lens.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not found in model.");
      }
      fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
      if (fShapeSequence_lens.size() != 1) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not of 1 dimension.");
      }
   }
   if (!fNInitial_h.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not found in model.");
      }
      fShapeInitial_h = model.GetTensorShape(fNInitial_h);
      if (fShapeInitial_h.size() != 3) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not of 3 dimensions.");
      }
   }
   if (!fNY.empty()) {
      fShapeY = ShapeInference({fShapeX, fShapeW})[0];
      if (!model.CheckIfTensorAlreadyExist(fNY)) {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
      }
   }
   if (!fNY_h.empty()) {
      fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
      if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
         model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
      }
   }
   // Check the attributes
   for (auto &activation : fAttrActivations) {
      if (activation != "Relu" && activation != "Tanh" &&
          activation != "Sigmoid" && activation != "Affine" &&
          activation != "LeakyRelu" && activation != "ThresholdRelu" &&
          activation != "ScaledTanh" && activation != "HardSigmoid" &&
          activation != "Elu" && activation != "Softsign" &&
          activation != "Softplus") {
         throw std::runtime_error("TMVA SOFIE - Activation function " +
                                  activation + " not implemented");
      }
   }
   if (fAttrDirection == "reverse") fAttrDirection = "backward";
   if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
       fAttrDirection != "bidirectional") {
      throw std::runtime_error(
          "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
          fAttrDirection);
   }
   if (3 * fAttrHiddenSize != fShapeW[1]) {
      throw std::runtime_error(
          "TMVA SOFIE - fAttrHiddenSize must be equal to " +
          std::to_string(fShapeW[1] / 3));
   }
   if (fAttrLayout > 1) {
      throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
                               std::to_string(fAttrLayout) +
                               " must be 0 (timewise) or 1 (batchwise)");
   }
   if (fAttrLinearBeforeReset > 1) {
      throw std::runtime_error(
         "TMVA SOFIE - fAttrLinearBeforeReset = " + std::to_string(fAttrLinearBeforeReset)
         + " must be 0 or 1.");
   }
   if (fAttrActivations.empty()) {
      if (fAttrDirection == "bidirectional") {
         fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
      } else {
         fAttrActivations = {"Sigmoid", "Tanh"};
      }
   }
}

// generate code for Session data members (e.g. internal vectors)
template <typename T>
std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
{
   opName = "op_" + opName;
   std::stringstream out;

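   // The generated Session owns these vectors so that the large intermediate
   // buffers live on the heap; without a session, Generate() falls back to
   // stack arrays. The layout-1 vectors hold the inputs transposed to the
   // time-major (layout 0) ordering used internally.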
   size_t num_directions = fShapeW[0];
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];

   if (fAttrLayout != 0) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
          << seq_length * batch_size * input_size << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_cell_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
   }
   // Set the feedforward
   size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_update_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_reset_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_hidden_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   // gate results
   size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";

   // feedback
   out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
       << batch_size * fAttrHiddenSize << ");\n";

   // hidden state
   if (fAttrLayout != 0 || fNY.empty()) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
   }

   out << "\n";

   return out.str();
}


template<typename T>
auto ROperator_GRU<T>::Generate(std::string OpName)
-> std::string {
   OpName = "op_" + OpName;
   std::stringstream out;

   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];
   size_t num_directions = fShapeW[0];

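   // Structure of the generated code: (1) transpose the input to time-major
   // order if layout = 1, (2) compute the input contribution of each gate for
   // all time steps at once with one GEMM per gate, (3) loop over time steps
   // adding the recurrent contributions and applying the activations, and
   // (4) pad and copy the hidden state into the requested outputs.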
   // set the input
   if (fAttrLayout == 0) {
      out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
      }
      out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
          << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
          << seq_length * input_size << " + seq * " << input_size << " + i];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

   // Set the initial hidden state
   if (!fNInitial_h.empty()) {
      if (fAttrLayout == 0) {
         out << SP << fType << " *" << OpName << "_initial_hidden_state = tensor_"
             << fNInitial_h << ";\n";
      } else {
         if (fUseSession) {
            out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
                << "_initial_hidden_state.data();\n";
         } else {
            out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
                fAttrHiddenSize << "];\n";
         }
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
            out << SP << SP << SP << OpName << "_initial_hidden_state["
                << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
                << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
                << " + " << direction * fAttrHiddenSize << " + h];\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
   }

   // Set the feedforward
   size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_f_update_gate = fVec_" << OpName << "_f_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_reset_gate = fVec_" << OpName << "_f_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_hidden_gate = fVec_" << OpName << "_f_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
   }
   // Set the gates
   size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_update_gate = fVec_" << OpName << "_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_reset_gate = fVec_" << OpName << "_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_hidden_gate = fVec_" << OpName << "_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
   }
   // Set the hidden state
   if (fAttrLayout == 0 && !fNY.empty()) {
      out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
      }
   }

   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_feedback = fVec_" << OpName << "_feedback.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
   }

   out << SP << "char " << OpName << "_transA = 'N';\n";
   out << SP << "char " << OpName << "_transB = 'T';\n";
   out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
   out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
   out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
   out << SP << "int " << OpName << "_k = " << input_size << ";\n";
   if (fType == "float") {
      out << SP << "float " << OpName << "_alpha = 1.;\n";
      out << SP << "float " << OpName << "_beta = 0.;\n";
   }
   if (!fNB.empty()) {
      out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
   }
   out << SP << "int " << OpName << "_incx = 1;\n";
   out << SP << "int " << OpName << "_incy = 1;\n";
   out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";

   for (size_t direction = 0; direction < num_directions; direction++) {
      if (direction == 0) {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
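            // Note: BLAS sgemm_ is Fortran (column-major). With trans = 'T' on the
            // weights and the row-major buffers passed with swapped dimensions, the
            // calls below compute, in row-major terms,
            //   f_gate[m x n] = input[m x k] * W^T[k x n],
            // with m = seq_length * batch_size, n = hidden_size, k = input_size;
            // the reset and hidden gates use W offset by one and two
            // hidden_size * input_size blocks.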
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
                << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      } else {
         if (fType == "float") {
            // f_update_gate = input * weight_z^T
            size_t wz_offset = 3 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";
            // f_reset_gate = input * weight_r^T
            size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";
            // f_hidden_gate = input * weight_h^T
            size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      }

      if (!fNB.empty()) {
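         // The bias tensor was broadcast in Initialize() to
         // [num_directions, 6, seq_length, batch_size, hidden_size], with the slabs
         // ordered Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h per direction, so each gate's
         // bias is one contiguous block of seq_length * batch_size * hidden_size
         // elements and the offsets below are multiples of that block size.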
         if (direction == 0) {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         } else {
            if (fType == "float") {
               // Add the bias of the weight to f_update_gate
               size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_update_gate
               size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_reset_gate
               size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the recurrence to f_reset_gate
               size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";
               // Add the bias of the weight to f_hidden_gate
               size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  // Add the bias of the recurrence to f_hidden_gate
                  size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         }
      }

      // Copy the feedforward into the gates
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t f_seq_size = batch_size * fAttrHiddenSize;
      out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
          << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
          << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
          << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
      out << SP << "}\n";

      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      if (fAttrDirection == "backward" || direction == 1) {
         out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
      } else {
         out << SP << SP << "size_t index = seq;\n";
      }
      out << SP << SP << "int m2 = " << batch_size << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t size = batch_size * fAttrHiddenSize;
      // gate = gate + initial_hidden_state * Recurrence^T
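      // Note: these recurrence GEMMs pass _alpha (1.0) in the beta position, so
      // they accumulate into the gate buffers: gate += hidden_state * R^T.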
      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                   << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         } else { // direction = 1
            if (fType == "float") {
               size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         }
      }
      out << SP << SP << "} else {\n";
      // gate = gate + previous_hidden_state * Recurrence^T
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
         if (fType == "float") {
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
            size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         if (fType == "float") {
            size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
                << OpName << "_n);\n";
            size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      }
      out << SP << SP << "}\n";

      // Clip the elements of the update gate and the reset gate into the range [-fAttrClip, fAttrClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
                << ") ? z : " << fAttrClip << ";\n";
            out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
                << ") ? r : " << fAttrClip << ";\n";
         }
         out << SP << SP << "}\n";
      }

      // Apply the activation function to the update gate and the reset gate
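      // The gate activation f defaults to Sigmoid (set in Initialize()); the
      // branches below inline each supported activation into the generated loop,
      // e.g. Tanh in the form (1 - e^{-2x}) / (1 + e^{-2x}).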
      if (fAttrActivations[direction * 2] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
            out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_update_gate[i]);\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = "
                << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
            out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_reset_gate[i]);\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = "
                << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
            out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
             << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
             << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2] == "Softplus"
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      }

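      // linear_before_reset selects where the reset gate is applied:
      //   0: h~ = g(X_t*W_h^T + (r_t o H_{t-1})*R_h^T + Rb_h + Wb_h)
      //   1: h~ = g(X_t*W_h^T + r_t o (H_{t-1}*R_h^T + Rb_h) + Wb_h)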
      if (fAttrLinearBeforeReset == 0) {
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = reset_gate o initial_hidden_state
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
                << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
            out << SP << SP << SP << "}\n";
         }
         out << SP << SP << "} else {\n";
         // feedback = reset_gate o previous_hidden_state
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
             << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << "}\n";
         // feedback = feedback * R_h^T
         size_t rh_offset = (direction == 0) ?
            2 * fAttrHiddenSize * fAttrHiddenSize : 3 * fAttrHiddenSize * fAttrHiddenSize
            + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
             << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
             << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
      } else { // fAttrLinearBeforeReset = 1
         // feedback = previous_hidden_state * R_h^T
         size_t rh_offset = (direction == 0)
                               ? 2 * fAttrHiddenSize * fAttrHiddenSize
                               : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            // feedback = initial_hidden_state * R_h^T
            out << SP << SP << SP
                << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
                << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
                << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         }
         out << SP << SP << "} else {\n";
         // case for seq > 0
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
             << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
             << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         // endif on seq == 0 or not
         out << SP << SP << "}\n";
         // Add the bias of the recurrence to feedback
         if (!fNB.empty()) {
            size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
                                                 : 11 * batch_size * seq_length * fAttrHiddenSize;
            out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
                << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
                << OpName << "_feedback, &" << OpName << "_incy);\n";
         }
         // feedback = reset_gate o feedback
         out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
         out << SP << SP << "}\n";
      }

      // hidden_gate = hidden_gate + feedback
      out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
          << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
          << OpName << "_incy);\n";

      // Clip the elements of the hidden gate into the range [-fAttrClip, fAttrClip]
      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
                << fAttrClip << ";\n";
         }
         out << SP << SP << "}\n";
      }

      // Apply the activation function to the hidden gate
      if (fAttrActivations[direction * 2 + 1] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
             << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
             << fAttrActivationBeta[direction * 2 + 1] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
                << " * " << OpName << "_hidden_gate[i]);\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = "
                << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
                << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
            out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
             << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
             << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // fAttrActivations[direction * 2 + 1] == "Softplus"
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
             << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      }

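      // H_t = (1 - z_t) o h~_t + z_t o H_{t-1}, computed in two steps: the
      // (1 - z) o h~ term below, then the z o H_{t-1} term added in the
      // seq == 0 / seq > 0 branches that follow.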
0864       // hidden_state = (1 - update_gate) o hidden_gate
0865       out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0866       out << SP << SP << SP << OpName << "_hidden_state[i] = ( 1. - " << OpName
0867           << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
0868       out << SP << SP << "}\n";
0869 
0870       out << SP << SP << "if (seq == 0) {\n";
0871       if (!fNInitial_h.empty()) {
0872          // hidden_state += update_gate o initial_hidden_state
0873          out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0874          out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
0875              << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
0876          out << SP << SP << SP << "}\n";
0877       }
      out << SP << SP << "} else {\n";
      // hidden_state += update_gate o previous_hidden_state
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
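      // The scratch hidden state is laid out as [seq][direction][batch][hidden], so
      // one time step corresponds to a stride of num_directions * batch_size *
      // fAttrHiddenSize elements, and direction 1 additionally skips the forward
      // slice of batch_size * fAttrHiddenSize elements. For example (assumed sizes:
      // num_directions = 2, batch_size = 3, fAttrHiddenSize = 4) the time-step
      // stride is 24 and the direction-1 slice starts 12 elements further in.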
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
          << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";

      out << SP << "}\n";
   }

   // Zero-pad the hidden state for GRU with different sequence lengths
   if (!fNSequence_lens.empty()) {
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
      for (size_t direction = 0; direction < num_directions; direction++) {
         out << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
         out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
             << num_directions * batch_size * fAttrHiddenSize << " + "
             << direction * batch_size * fAttrHiddenSize
             << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
         out << SP << SP << SP << SP << "}\n";
      }
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }
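   // For batch entries shorter than seq_length the recurrence produces no valid
   // output past sequence_lens[batch], so those positions are zeroed explicitly
   // (matching the usual ONNX runtime convention for the Y output).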

   // Copy the hidden state into Y and Y_h
   if (fAttrLayout == 0) {
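      // With layout == 0 the scratch hidden state already matches Y's shape
      // [seq_length, num_directions, batch_size, hidden_size] (see ShapeInference
      // above), so only Y_h requires an explicit copy in this branch.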
      if (!fNY_h.empty()) {
         // Copy hidden_state into Y_h
         if (fNSequence_lens.empty()) {
            size_t yh_size = batch_size * fAttrHiddenSize;
            if (fAttrDirection == "backward") {
               out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
                   << yh_size << ", tensor_" << fNY_h << ");\n";
            } else {
               size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
               out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
                   << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
            }
            if (num_directions == 2) {
               out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
                   << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
            }
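            // The forward pass finishes at seq_length - 1 while the reverse pass
            // finishes at seq 0; hence the backward copy starts at the beginning of
            // the buffer, and the bidirectional case additionally copies the
            // direction-1 slice of time step 0, which starts at offset yh_size.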
         } else { // GRU with different sequence lengths
            if (fAttrDirection == "backward") {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
               out << SP << "}\n";
            } else {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t seq = tensor_" << fNSequence_lens << "[batch] - 1;\n";
               out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
            if (num_directions == 2) {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
         }
      }
   } else { // fAttrLayout == 1
      if (!fNY.empty()) {
         // Copy hidden_state into Y
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
            out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
                << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
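      // The copy above remaps the internal [seq, direction, batch, hidden] layout to
      // Y's layout-1 shape [batch_size, seq_length, num_directions, hidden_size].
      // As a worked example (assumed sizes S = 5, D = 2, B = 3, H = 4): the slice
      // (seq = 2, dir = 1, batch = 0) starts at offset 2*24 + 12 = 60 internally
      // and at y_offset 2*8 + 1*4 = 20 in Y.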
      if (!fNY_h.empty()) {
         // Copy the hidden_state into Y_h
         if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         } else {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            if (fNSequence_lens.empty()) {
               out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
            } else {
               out << SP << SP << "size_t seq = tensor_" << fNSequence_lens << "[batch] - 1;\n";
            }
            out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         if (num_directions == 2) {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
      }
   }

   return out.str();
}
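
// As an illustration of Generate()'s emitted code (a sketch, assuming a hypothetical
// operator name "op", a unidirectional forward GRU with layout 0, fNY_h == "Y_h",
// batch_size = 1, hidden_size = 4 and seq_length = 3), the final Y_h copy generated
// above would read:
//
//    std::copy(op_hidden_state + 8, op_hidden_state + 8 + 4, tensor_Y_h);
//
// i.e. the last time step (seq = 2) of the scratch hidden state.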

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif