#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
#define TMVA_SOFIE_ROPERATOR_GRU_I

namespace TMVA {
namespace Experimental {
namespace SOFIE {

template <typename T>
auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
-> std::vector<ETensorType> {
   ETensorType out = input[0];
   return {out, out};
}

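// Infers the output shapes from the shapes of X and W. With layout = 0 the
// outputs are Y = {seq_length, num_directions, batch_size, hidden_size} and
// Y_h = {num_directions, batch_size, hidden_size}; with layout = 1 the batch
// dimension comes first. The hidden size is fShapeW[1] / 3 because W stacks
// the three gate matrices W_z, W_r and W_h.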
template<typename T>
auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
-> std::vector<std::vector<size_t>> {
   size_t num_directions = input[1][0];
   size_t hidden_size = input[1][1] / 3;
   if (fAttrLayout == 0) {
      size_t seq_length = input[0][0];
      size_t batch_size = input[0][1];
      std::vector<std::vector<size_t>> ret(
          {{seq_length, num_directions, batch_size, hidden_size},
           {num_directions, batch_size, hidden_size}});
      return ret;
   } else {
      size_t batch_size = input[0][0];
      size_t seq_length = input[0][1];
      std::vector<std::vector<size_t>> ret(
          {{batch_size, seq_length, num_directions, hidden_size},
           {batch_size, num_directions, hidden_size}});
      return ret;
   }
}

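// Validates the input tensors (X, W, R and the optional B, sequence_lens and
// initial_h), registers the outputs Y and Y_h, broadcasts the bias tensor and
// checks the attributes against the ONNX GRU specification.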
template<typename T>
void ROperator_GRU<T>::Initialize(RModel& model) {

   fUseSession = model.UseSession();

   if (!model.CheckIfTensorAlreadyExist(fNX)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
   }
   fShapeX = model.GetTensorShape(fNX);
   if (fShapeX.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNW)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
   }
   fShapeW = model.GetTensorShape(fNW);
   if (fShapeW.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNR)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
   }
   fShapeR = model.GetTensorShape(fNR);
   if (fShapeR.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
   }
   if (!fNB.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not found in model.");
      }
      fShapeB = model.GetTensorShape(fNB);
      if (fShapeB.size() != 2 && fShapeB.size() != 4) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not of 2 or 4 dimensions.");
      }
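      // ONNX packs the bias as {num_directions, 6 * hidden_size}, i.e. the
      // concatenation [Wb_z, Wb_r, Wb_h, Rb_z, Rb_r, Rb_h]. Broadcast it to
      // {num_directions, 6, seq_length, batch_size, hidden_size} so that each
      // gate bias can later be added with a single BLAS saxpy call.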
      if (fShapeB.size() == 2) {
         auto original_data = model.GetInitializedTensorData(fNB);
         size_t num_directions = fShapeW[0];
         size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
         size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
         if (fType == "float") {
            float *original_bias = static_cast<float*>(original_data.get());
            float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
            for (size_t direction = 0; direction < num_directions; direction++) {
               for (size_t i = 0; i < 6; i++) {
                  for (size_t seq = 0; seq < seq_length; seq++) {
                     for (size_t batch = 0; batch < batch_size; batch++) {
                        size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
                        size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
                                        i * batch_size * seq_length * fAttrHiddenSize +
                                        seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
                        std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
                                  new_bias + offset);
                     }
                  }
               }
            }
            std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
            std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
            model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
            fShapeB = model.GetTensorShape(fNB);
         }
      }
   }
   if (!fNSequence_lens.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not found in model.");
      }
      fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
      if (fShapeSequence_lens.size() != 1) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not of 1 dimension.");
      }
   }
   if (!fNInitial_h.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not found in model.");
      }
      fShapeInitial_h = model.GetTensorShape(fNInitial_h);
      if (fShapeInitial_h.size() != 3) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not of 3 dimensions.");
      }
   }
   if (!fNY.empty()) {
      fShapeY = ShapeInference({fShapeX, fShapeW})[0];
      if (!model.CheckIfTensorAlreadyExist(fNY)) {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
      }
   }
   if (!fNY_h.empty()) {
      fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
      if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
         model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
      }
   }

   for (auto &activation : fAttrActivations) {
      if (activation != "Relu" && activation != "Tanh" &&
          activation != "Sigmoid" && activation != "Affine" &&
          activation != "LeakyRelu" && activation != "ThresholdRelu" &&
          activation != "ScaledTanh" && activation != "HardSigmoid" &&
          activation != "Elu" && activation != "Softsign" &&
          activation != "Softplus") {
         throw std::runtime_error("TMVA SOFIE - Activation function " +
                                  activation + " not implemented");
      }
   }
   if (fAttrDirection == "reverse") fAttrDirection = "backward";
   if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
       fAttrDirection != "bidirectional") {
      throw std::runtime_error(
          "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
          fAttrDirection);
   }
   if (3 * fAttrHiddenSize != fShapeW[1]) {
      throw std::runtime_error(
          "TMVA SOFIE - fAttrHiddenSize must be equal to " +
          std::to_string(fShapeW[1] / 3));
   }
   if (fAttrLayout > 1) {
      throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
                               std::to_string(fAttrLayout) +
                               " must be 0 (timewise) or 1 (batchwise)");
   }
   if (fAttrLinearBeforeReset > 1) {
      throw std::runtime_error(
          "TMVA SOFIE - fAttrLinearBeforeReset = " + std::to_string(fAttrLinearBeforeReset)
          + " must be 0 or 1.");
   }
   if (fAttrActivations.empty()) {
      if (fAttrDirection == "bidirectional") {
         fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
      } else {
         fAttrActivations = {"Sigmoid", "Tanh"};
      }
   }
}

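// Emits the session data members: one std::vector per intermediate buffer, so
// that the generated inference function does not have to allocate large
// arrays on the stack at every call.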
template <typename T>
std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
{
   opName = "op_" + opName;
   std::stringstream out;

   size_t num_directions = fShapeW[0];
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];

   if (fAttrLayout != 0) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
          << seq_length * batch_size * input_size << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
   }

   size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_update_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_reset_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_hidden_gate = std::vector<" << fType << ">(" << ff_size << ");\n";

   size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";

   out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
       << batch_size * fAttrHiddenSize << ");\n";

   if (fAttrLayout != 0 || fNY.empty()) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
   }

   out << "\n";

   return out.str();
}

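// Generates the inference code for the GRU operator: BLAS sgemm/saxpy calls
// for the matrix products and explicit loops for the element-wise activation
// and gating operations.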
template<typename T>
auto ROperator_GRU<T>::Generate(std::string OpName)
-> std::string {
   OpName = "op_" + OpName;
   std::stringstream out;

   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];
   size_t num_directions = fShapeW[0];

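   // With layout = 1 the input arrives as {batch_size, seq_length, input_size};
   // transpose it into the {seq_length, batch_size, input_size} order used
   // internally.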
   if (fAttrLayout == 0) {
      out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
      }
      out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
          << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
          << seq_length * input_size << " + seq * " << input_size << " + i];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

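   // Make the initial hidden state contiguous per direction; with layout = 1
   // it has to be transposed from {batch_size, num_directions, hidden_size}.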
   if (!fNInitial_h.empty()) {
      if (fAttrLayout == 0) {
         out << SP << fType << " *" << OpName << "_initial_hidden_state = tensor_"
             << fNInitial_h << ";\n";
      } else {
         if (fUseSession) {
            out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
                << "_initial_hidden_state.data();\n";
         } else {
            out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
                fAttrHiddenSize << "];\n";
         }
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
            out << SP << SP << SP << OpName << "_initial_hidden_state["
                << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
                << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
                << " + " << direction * fAttrHiddenSize << " + h];\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
   }

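   // The f_* buffers hold the input contribution X * W^T of one direction at a
   // time; the update/reset/hidden gate buffers hold the gate values for every
   // time step and direction.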
   size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_f_update_gate = fVec_" << OpName << "_f_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_reset_gate = fVec_" << OpName << "_f_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_f_hidden_gate = fVec_" << OpName << "_f_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
   }

   size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_update_gate = fVec_" << OpName << "_update_gate.data();\n";
      out << SP << fType << " * " << OpName << "_reset_gate = fVec_" << OpName << "_reset_gate.data();\n";
      out << SP << fType << " * " << OpName << "_hidden_gate = fVec_" << OpName << "_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
   }

   if (fAttrLayout == 0 && !fNY.empty()) {
      out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
      }
   }

   if (fUseSession) {
      out << SP << fType << " * " << OpName << "_feedback = fVec_" << OpName << "_feedback.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
   }

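   // BLAS arguments. The data is row-major while BLAS expects column-major
   // storage, so the generated sgemm calls pass the operands in swapped order
   // with transB = 'T', which yields the row-major product
   // gate(m x n) = input(m x k) * W^T(k x n).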
0324 out << SP << "char " << OpName << "_transA = 'N';\n";
0325 out << SP << "char " << OpName << "_transB = 'T';\n";
0326 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
0327 out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
0328 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
0329 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
0330 if (fType == "float") {
0331 out << SP << "float " << OpName << "_alpha = 1.;\n";
0332 out << SP << "float " << OpName << "_beta = 0.;\n";
0333 }
0334 if (!fNB.empty()) {
0335 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
0336 }
0337 out << SP << "int " << OpName << "_incx = 1;\n";
0338 out << SP << "int " << OpName << "_incy = 1;\n";
0339 out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";
0340
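   // Compute the gates for each direction following the ONNX GRU definition:
   //   z_t = f(X_t W_z^T + H_{t-1} R_z^T + Wb_z + Rb_z)
   //   r_t = f(X_t W_r^T + H_{t-1} R_r^T + Wb_r + Rb_r)
   //   h~_t = g(X_t W_h^T + (r_t . H_{t-1}) R_h^T + Wb_h + Rb_h)    (linear_before_reset = 0)
   //   h~_t = g(X_t W_h^T + r_t . (H_{t-1} R_h^T + Rb_h) + Wb_h)    (linear_before_reset = 1)
   //   H_t  = (1 - z_t) . h~_t + z_t . H_{t-1}
   // where . denotes the element-wise product, f the first activation
   // (default Sigmoid) and g the second one (default Tanh).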
   for (size_t direction = 0; direction < num_directions; direction++) {
      if (direction == 0) {
         if (fType == "float") {
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
                << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";

            size_t wr_offset = fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";

            size_t wh_offset = 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      } else {
         if (fType == "float") {
            size_t wz_offset = 3 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";

            size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";

            size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      }

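      // Add the broadcast biases with saxpy. The offsets select the gate slot
      // in the {num_directions, 6, seq_length, batch_size, hidden_size} bias
      // tensor: slots 0-5 for the forward direction, 6-11 for the backward
      // one. The recurrence bias Rb_h is added later when linear_before_reset
      // is set.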
      if (!fNB.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";

               size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";

               size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         } else {
            if (fType == "float") {
               size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";

               size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";

               size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         }
      }

      // Copy this direction's feed-forward contribution into the per-step gate buffers.
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t f_seq_size = batch_size * fAttrHiddenSize;
      out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
          << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
          << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
          << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
      out << SP << "}\n";

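      // Recurrence over the time steps; the backward pass walks the sequence
      // in reverse. Each step adds the H_{t-1} R^T contribution to the update
      // and reset gates, then applies the optional clipping and the first
      // activation f.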
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      if (fAttrDirection == "backward" || direction == 1) {
         out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
      } else {
         out << SP << SP << "size_t index = seq;\n";
      }
      out << SP << SP << "int m2 = " << batch_size << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t size = batch_size * fAttrHiddenSize;

      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                   << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         } else {
            if (fType == "float") {
               size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         }
      }
      out << SP << SP << "} else {\n";
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
         if (fType == "float") {
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
            size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         if (fType == "float") {
            size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
                << OpName << "_n);\n";
            size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      }
      out << SP << SP << "}\n";

      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
                << ") ? z : " << fAttrClip << ";\n";
            out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
                << ") ? r : " << fAttrClip << ";\n";
         }
         out << SP << SP << "}\n";
      }

      // Apply the first activation function f to the update and reset gates.
      if (fAttrActivations[direction * 2] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
            out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_update_gate[i]);\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = "
                << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
            out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_reset_gate[i]);\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = "
                << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
            out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
             << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
             << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      }

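      // Feedback term of the hidden gate. With linear_before_reset = 0 the
      // reset gate is applied before the recurrence product, (r_t . H_{t-1}) R_h^T;
      // with linear_before_reset = 1 it multiplies the already transformed
      // state, r_t . (H_{t-1} R_h^T + Rb_h).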
      if (fAttrLinearBeforeReset == 0) {
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
                << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
            out << SP << SP << SP << "}\n";
         }
         out << SP << SP << "} else {\n";
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
             << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << "}\n";

         size_t rh_offset = (direction == 0)
             ? 2 * fAttrHiddenSize * fAttrHiddenSize
             : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
             << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
             << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
      } else {
         size_t rh_offset = (direction == 0)
             ? 2 * fAttrHiddenSize * fAttrHiddenSize
             : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            out << SP << SP << SP
                << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
                << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
                << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         }
         out << SP << SP << "} else {\n";
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
             << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
             << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         out << SP << SP << "}\n";

         if (!fNB.empty()) {
            size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
                                                 : 11 * batch_size * seq_length * fAttrHiddenSize;
            out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
                << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
                << OpName << "_feedback, &" << OpName << "_incy);\n";
         }

         out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
         out << SP << SP << "}\n";
      }

      // Add the recurrence contribution to the hidden gate.
      out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
          << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
          << OpName << "_incy);\n";

      if (fAttrClip > .0) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
                << fAttrClip << ";\n";
         }
         out << SP << SP << "}\n";
      }

      // Apply the second activation function g to the hidden gate.
      if (fAttrActivations[direction * 2 + 1] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
             << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
             << fAttrActivationBeta[direction * 2 + 1] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
                << " * " << OpName << "_hidden_gate[i]);\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = "
                << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
                << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
            out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
             << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
             << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else { // Softplus
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
             << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      }

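      // Combine into the new hidden state: H_t = (1 - z_t) . h~_t + z_t . H_{t-1}.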
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << OpName << "_hidden_state[i] = (1. - " << OpName
          << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
      out << SP << SP << "}\n";

      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
             << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
         out << SP << SP << SP << "}\n";
      }
      out << SP << SP << "} else {\n";
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
          << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";

      out << SP << "}\n";
   }

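   // With sequence_lens given, zero the hidden state of the padded time steps,
   // i.e. those at or beyond the per-batch sequence length.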
   if (!fNSequence_lens.empty()) {
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
      for (size_t direction = 0; direction < num_directions; direction++) {
         out << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
         out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
             << num_directions * batch_size * fAttrHiddenSize << " + "
             << direction * batch_size * fAttrHiddenSize
             << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
         out << SP << SP << SP << SP << "}\n";
      }
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

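   // Copy the results into the output tensors. With layout = 0, Y (if
   // requested) already aliases the hidden state buffer, so only Y_h has to be
   // filled; with layout = 1 both Y and Y_h are written out in batch-major
   // order.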
   if (fAttrLayout == 0) {
      if (!fNY_h.empty()) {
         if (fNSequence_lens.empty()) {
            size_t yh_size = batch_size * fAttrHiddenSize;
            if (fAttrDirection == "backward") {
               out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
                   << yh_size << ", tensor_" << fNY_h << ");\n";
            } else {
               size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
               out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
                   << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
            }
            if (num_directions == 2) {
               out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
                   << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
            }
         } else {
            if (fAttrDirection == "backward") {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
               out << SP << "}\n";
            } else {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
               out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
            if (num_directions == 2) {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
         }
      }
   } else {
      if (!fNY.empty()) {
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
            out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
                << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
      if (!fNY_h.empty()) {
         if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         } else {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            if (fNSequence_lens.empty()) {
               out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
            } else {
               out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
            }
            out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         if (num_directions == 2) {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
      }
   }

   return out.str();
}

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif