#ifndef TMVA_SOFIE_ROPERATOR_GRU_I
#define TMVA_SOFIE_ROPERATOR_GRU_I

namespace TMVA {
namespace Experimental {
namespace SOFIE {

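// GRU operator following the ONNX specification. For each time step t the
// generated code evaluates
//    z_t = f(X_t W_z^T + H_{t-1} R_z^T + Wb_z + Rb_z)
//    r_t = f(X_t W_r^T + H_{t-1} R_r^T + Wb_r + Rb_r)
//    h_t = g(X_t W_h^T + (r_t . H_{t-1}) R_h^T + Wb_h + Rb_h)   (linear_before_reset = 0)
//    H_t = (1 - z_t) . h_t + z_t . H_{t-1}
// where f and g are the configurable gate activations and "." denotes the
// element-wise product.

// Both outputs (Y, Y_h) inherit the element type of the input tensor X.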
template <typename T>
auto ROperator_GRU<T>::TypeInference(std::vector<ETensorType> input)
-> std::vector<ETensorType> {
   ETensorType out = input[0];
   return {out, out};
}

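// With X = [seq_length, batch_size, input_size] (layout 0) and
// W = [num_directions, 3*hidden_size, input_size], the outputs are
// Y   = [seq_length, num_directions, batch_size, hidden_size] and
// Y_h = [num_directions, batch_size, hidden_size]; layout 1 moves the
// batch dimension to the front of both shapes.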
template <typename T>
auto ROperator_GRU<T>::ShapeInference(std::vector<std::vector<size_t>> input)
-> std::vector<std::vector<size_t>> {
   size_t num_directions = input[1][0];
   size_t hidden_size = input[1][1] / 3;
   if (fAttrLayout == 0) {
      size_t seq_length = input[0][0];
      size_t batch_size = input[0][1];
      std::vector<std::vector<size_t>> ret(
          {{seq_length, num_directions, batch_size, hidden_size},
           {num_directions, batch_size, hidden_size}});
      return ret;
   } else {
      size_t batch_size = input[0][0];
      size_t seq_length = input[0][1];
      std::vector<std::vector<size_t>> ret(
          {{batch_size, seq_length, num_directions, hidden_size},
           {batch_size, num_directions, hidden_size}});
      return ret;
   }
}

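// Validate the input tensors, broadcast the bias when it is given in its
// compact form and register the output tensors with the model.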
template <typename T>
void ROperator_GRU<T>::Initialize(RModel &model) {
   fUseSession = model.UseSession();

   if (!model.CheckIfTensorAlreadyExist(fNX)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not found in model.");
   }
   fShapeX = model.GetTensorShape(fNX);
   if (fShapeX.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNX + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNW)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not found in model.");
   }
   fShapeW = model.GetTensorShape(fNW);
   if (fShapeW.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNW + " is not of 3 dimensions.");
   }
   if (!model.CheckIfTensorAlreadyExist(fNR)) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not found in model.");
   }
   fShapeR = model.GetTensorShape(fNR);
   if (fShapeR.size() != 3) {
      throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNR + " is not of 3 dimensions.");
   }
   if (!fNB.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNB)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not found in model.");
      }
      fShapeB = model.GetTensorShape(fNB);
      if (fShapeB.size() != 2 && fShapeB.size() != 4) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " + fNB + " is not of 2 or 4 dimensions.");
      }
      if (fShapeB.size() == 2) {
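         // Broadcast the bias from [num_directions, 6*hidden_size] to
         // [num_directions, 6, seq_length, batch_size, hidden_size] so that each
         // of the six bias vectors can later be added with a single saxpy call.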
         auto original_data = model.GetInitializedTensorData(fNB);
         size_t num_directions = fShapeW[0];
         size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
         size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
         if (fType == "float") {
            float *original_bias = static_cast<float *>(original_data.get());
            float *new_bias = new float[num_directions * 6 * seq_length * batch_size * fAttrHiddenSize];
            for (size_t direction = 0; direction < num_directions; direction++) {
               for (size_t i = 0; i < 6; i++) {
                  for (size_t seq = 0; seq < seq_length; seq++) {
                     for (size_t batch = 0; batch < batch_size; batch++) {
                        size_t bias_offset = direction * 6 * fAttrHiddenSize + i * fAttrHiddenSize;
                        size_t offset = direction * 6 * batch_size * seq_length * fAttrHiddenSize +
                                        i * batch_size * seq_length * fAttrHiddenSize +
                                        seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
                        std::copy(original_bias + bias_offset, original_bias + bias_offset + fAttrHiddenSize,
                                  new_bias + offset);
                     }
                  }
               }
            }

            std::vector<size_t> new_bias_shape = {num_directions, 6, seq_length, batch_size, fAttrHiddenSize};
            std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
            model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
            fShapeB = model.GetTensorShape(fNB);
         }
      }
   }
   if (!fNSequence_lens.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not found in model.");
      }
      fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
      if (fShapeSequence_lens.size() != 1) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNSequence_lens +
                                  " is not of 1 dimension.");
      }
   }
   if (!fNInitial_h.empty()) {
      if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not found in model.");
      }
      fShapeInitial_h = model.GetTensorShape(fNInitial_h);
      if (fShapeInitial_h.size() != 3) {
         throw std::runtime_error("TMVA SOFIE GRU Op input tensor " +
                                  fNInitial_h + " is not of 3 dimensions.");
      }
   }
   if (!fNY.empty()) {
      fShapeY = ShapeInference({fShapeX, fShapeW})[0];
      if (!model.CheckIfTensorAlreadyExist(fNY)) {
         model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
      }
   }
   if (!fNY_h.empty()) {
      fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
      if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
         model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
      }
   }

   for (auto &activation : fAttrActivations) {
      if (activation != "Relu" && activation != "Tanh" &&
          activation != "Sigmoid" && activation != "Affine" &&
          activation != "LeakyRelu" && activation != "ThresholdRelu" &&
          activation != "ScaledTanh" && activation != "HardSigmoid" &&
          activation != "Elu" && activation != "Softsign" &&
          activation != "Softplus") {
         throw std::runtime_error("TMVA SOFIE - Activation function " +
                                  activation + " not implemented");
      }
   }
   if (fAttrDirection == "reverse") fAttrDirection = "backward";
   if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
       fAttrDirection != "bidirectional") {
      throw std::runtime_error(
          "TMVA SOFIE - Invalid GRU direction fAttrDirection = " +
          fAttrDirection);
   }
   if (3 * fAttrHiddenSize != fShapeW[1]) {
      throw std::runtime_error(
          "TMVA SOFIE - fAttrHiddenSize must be equal to " +
          std::to_string(fShapeW[1] / 3));
   }
   if (fAttrLayout > 1) {
      throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
                               std::to_string(fAttrLayout) +
                               " must be 0 (timewise) or 1 (batchwise)");
   }
   if (fAttrLinearBeforeReset > 1) {
      throw std::runtime_error(
          "TMVA SOFIE - fAttrLinearBeforeReset = " + std::to_string(fAttrLinearBeforeReset) +
          " must be 0 or 1.");
   }
   if (fAttrActivations.empty()) {
      if (fAttrDirection == "bidirectional") {
         fAttrActivations = {"Sigmoid", "Tanh", "Sigmoid", "Tanh"};
      } else {
         fAttrActivations = {"Sigmoid", "Tanh"};
      }
   }
}

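// Declare the session data members that hold the intermediate buffers:
// the gate values, the feedback buffer and, for layout 1, transposed
// copies of the input and of the initial hidden state.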
template <typename T>
std::string ROperator_GRU<T>::GenerateSessionMembersCode(std::string opName)
{
   opName = "op_" + opName;
   std::stringstream out;

   size_t num_directions = fShapeW[0];
   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];

   if (fAttrLayout != 0) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
          << seq_length * batch_size * input_size << ");\n";
      out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
          << num_directions * batch_size * fAttrHiddenSize << ");\n";
   }

   size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_update_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_reset_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_f_hidden_gate = std::vector<" << fType << ">(" << ff_size << ");\n";

   size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   out << "std::vector<" << fType << "> fVec_" << opName << "_update_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_reset_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
   out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_gate = std::vector<" << fType << ">(" << hs_size << ");\n";

   out << "std::vector<" << fType << "> fVec_" << opName << "_feedback = std::vector<" << fType << ">("
       << batch_size * fAttrHiddenSize << ");\n";

   if (fAttrLayout != 0 || fNY.empty()) {
      out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
   }

   out << "\n";

   return out.str();
}

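// Generate the inference code for one GRU evaluation. The linear algebra
// is delegated to BLAS (sgemm/saxpy); gate activations and the recurrence
// itself are emitted as explicit loops.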
template <typename T>
auto ROperator_GRU<T>::Generate(std::string OpName)
-> std::string {
   OpName = "op_" + OpName;
   std::stringstream out;

   size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
   size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
   size_t input_size = fShapeX[2];
   size_t num_directions = fShapeW[0];

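   // With layout 1 (batchwise) the input is first transposed into the
   // timewise order [seq_length, batch_size, input_size] used below.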
   if (fAttrLayout == 0) {
      out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " *" << OpName << "_input = fVec_" << OpName << "_input.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "];\n";
      }
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "for (size_t i = 0; i < " << input_size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
          << " + batch * " << input_size << " + i] = tensor_" << fNX << "[batch * "
          << seq_length * input_size << " + seq * " << input_size << " + i];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

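   // Gather the initial hidden state, transposing it when layout is 1.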
   if (!fNInitial_h.empty()) {
      if (fAttrLayout == 0) {
         out << SP << fType << " *" << OpName << "_initial_hidden_state = tensor_"
             << fNInitial_h << ";\n";
      } else {
         if (fUseSession) {
            out << SP << fType << " *" << OpName << "_initial_hidden_state = fVec_" << OpName
                << "_initial_hidden_state.data();\n";
         } else {
            out << SP << fType << " " << OpName << "_initial_hidden_state["
                << num_directions * batch_size * fAttrHiddenSize << "];\n";
         }
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
            out << SP << SP << SP << OpName << "_initial_hidden_state["
                << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
                << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
                << " + " << direction * fAttrHiddenSize << " + h];\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
   }

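   // Scratch buffers: the f_* arrays hold the feedforward part X*W^T of each
   // gate for the current direction; update/reset/hidden_gate hold the full
   // gate values for all time steps, interleaved over the directions.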
   size_t feedforward_size = seq_length * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " *" << OpName << "_f_update_gate = fVec_" << OpName << "_f_update_gate.data();\n";
      out << SP << fType << " *" << OpName << "_f_reset_gate = fVec_" << OpName << "_f_reset_gate.data();\n";
      out << SP << fType << " *" << OpName << "_f_hidden_gate = fVec_" << OpName << "_f_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_f_update_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_reset_gate[" << feedforward_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_f_hidden_gate[" << feedforward_size << "] = {0};\n";
   }

   size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
   if (fUseSession) {
      out << SP << fType << " *" << OpName << "_update_gate = fVec_" << OpName << "_update_gate.data();\n";
      out << SP << fType << " *" << OpName << "_reset_gate = fVec_" << OpName << "_reset_gate.data();\n";
      out << SP << fType << " *" << OpName << "_hidden_gate = fVec_" << OpName << "_hidden_gate.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_update_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_reset_gate[" << hidden_state_size << "] = {0};\n";
      out << SP << fType << " " << OpName << "_hidden_gate[" << hidden_state_size << "] = {0};\n";
   }

   if (fAttrLayout == 0 && !fNY.empty()) {
      out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
   } else {
      if (fUseSession) {
         out << SP << fType << " *" << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
      } else {
         out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
      }
   }

   if (fUseSession) {
      out << SP << fType << " *" << OpName << "_feedback = fVec_" << OpName << "_feedback.data();\n";
   } else {
      out << SP << fType << " " << OpName << "_feedback[" << batch_size * fAttrHiddenSize << "] = {0};\n";
   }

   out << SP << "char " << OpName << "_transA = 'N';\n";
   out << SP << "char " << OpName << "_transB = 'T';\n";
   out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
   out << SP << "int " << OpName << "_m2 = " << batch_size << ";\n";
   out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
   out << SP << "int " << OpName << "_k = " << input_size << ";\n";
   if (fType == "float") {
      out << SP << "float " << OpName << "_alpha = 1.;\n";
      out << SP << "float " << OpName << "_beta = 0.;\n";
   }
   if (!fNB.empty()) {
      out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
   }
   out << SP << "int " << OpName << "_incx = 1;\n";
   out << SP << "int " << OpName << "_incy = 1;\n";
   out << SP << "int " << OpName << "_feedback_size = " << batch_size * fAttrHiddenSize << ";\n";

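   // Loop over the directions: compute the feedforward contribution of the
   // three gates with sgemm (the weights of the second direction start at an
   // offset of 3*hidden_size*input_size), then add the broadcast biases.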
   for (size_t direction = 0; direction < num_directions; direction++) {
      if (direction == 0) {
         if (fType == "float") {
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
                << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";

            size_t wr_offset = fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";

            size_t wh_offset = 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      } else {
         if (fType == "float") {
            size_t wz_offset = 3 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wz_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_update_gate, &" << OpName << "_n);\n";

            size_t wr_offset = 3 * fAttrHiddenSize * input_size + fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wr_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_reset_gate, &" << OpName << "_n);\n";

            size_t wh_offset = 3 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
            out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
                << fNW << " + " << wh_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
                << OpName << "_k, &" << OpName << "_beta, " << OpName << "_f_hidden_gate, &" << OpName << "_n);\n";
         }
      }

      if (!fNB.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &" << OpName << "_incy);\n";

               size_t rbz_offset = 3 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";

               size_t wbr_offset = batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t rbr_offset = 4 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t wbh_offset = 2 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  size_t rbh_offset = 5 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         } else {
            if (fType == "float") {
               size_t wbz_offset = 6 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";

               size_t rbz_offset = 9 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbz_offset << ", &" << OpName << "_incx, " << OpName << "_f_update_gate, &"
                   << OpName << "_incy);\n";

               size_t wbr_offset = 7 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t rbr_offset = 10 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << rbr_offset << ", &" << OpName << "_incx, " << OpName << "_f_reset_gate, &"
                   << OpName << "_incy);\n";

               size_t wbh_offset = 8 * batch_size * seq_length * fAttrHiddenSize;
               out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                   << fNB << " + " << wbh_offset << ", &" << OpName << "_incx, " << OpName << "_f_hidden_gate, &"
                   << OpName << "_incy);\n";
               if (fAttrLinearBeforeReset == 0) {
                  size_t rbh_offset = 11 * batch_size * seq_length * fAttrHiddenSize;
                  out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
                      << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, " << OpName
                      << "_f_hidden_gate, &" << OpName << "_incy);\n";
               }
            }
         }
      }

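      // Scatter the feedforward results into the direction-interleaved gate
      // buffers, one time step at a time.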
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "size_t offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t f_seq_size = batch_size * fAttrHiddenSize;
      out << SP << SP << "std::copy(" << OpName << "_f_update_gate + offset, " << OpName
          << "_f_update_gate + offset + " << f_seq_size << ", " << OpName << "_update_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_reset_gate + offset, " << OpName
          << "_f_reset_gate + offset + " << f_seq_size << ", " << OpName << "_reset_gate + gate_offset);\n";
      out << SP << SP << "std::copy(" << OpName << "_f_hidden_gate + offset, " << OpName
          << "_f_hidden_gate + offset + " << f_seq_size << ", " << OpName << "_hidden_gate + gate_offset);\n";
      out << SP << "}\n";

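      // Recurrence: walk the sequence (in reverse order for the backward
      // direction) and add the H_{t-1}*R^T contribution to the update and
      // reset gates.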
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      if (fAttrDirection == "backward" || direction == 1) {
         out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
      } else {
         out << SP << SP << "size_t index = seq;\n";
      }
      out << SP << SP << "int m2 = " << batch_size << ";\n";
      if (direction == 0) {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << ";\n";
      } else {
         out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
             << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      size_t size = batch_size * fAttrHiddenSize;

      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         if (direction == 0) {
            if (fType == "float") {
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                   << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
                   << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         } else {
            if (fType == "float") {
               size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rz_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
               size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
               out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                   << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                   << rr_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
                   << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &" << OpName << "_n);\n";
            }
         }
      }
      out << SP << SP << "} else {\n";
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
         if (fType == "float") {
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
                << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
                << OpName << "_alpha, " << OpName << "_update_gate + offset, &" << OpName << "_n);\n";
            size_t rr_offset = fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         if (fType == "float") {
            size_t rz_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rz_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_update_gate + offset, &"
                << OpName << "_n);\n";
            size_t rr_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
            out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
                << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rr_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
                << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_reset_gate + offset, &"
                << OpName << "_n);\n";
         }
      }
      out << SP << SP << "}\n";

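      // Clip the update and reset gates to [-fAttrClip, fAttrClip] before
      // applying the activation.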
      if (fAttrClip > 0.) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = (" << OpName << "_update_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_update_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (z < " << fAttrClip
                << ") ? z : " << fAttrClip << ";\n";
            out << SP << SP << SP << "float r = (" << OpName << "_reset_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_reset_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (r < " << fAttrClip
                << ") ? r : " << fAttrClip << ";\n";
         }
         out << SP << SP << "}\n";
      }

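      // Apply the first activation function f to the update and reset gates;
      // the final else branch emits Softplus.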
      if (fAttrActivations[direction * 2] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << OpName << "_update_gate[i]);\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (1. - z) / (1. + z);\n";
            out << SP << SP << SP << "float r = exp(-2 * " << OpName << "_reset_gate[i]);\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (1. - r) / (1. + r);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = 1. / (1. + exp(-"
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = 1. / (1. + exp(-"
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i] + "
             << fAttrActivationBeta[direction * 2] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float z = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_update_gate[i]);\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = "
                << fAttrActivationAlpha[direction * 2] << " * (1. - z) / (1. + z);\n";
            out << SP << SP << SP << "float r = exp(-2 * " << fAttrActivationBeta[direction * 2]
                << " * " << OpName << "_reset_gate[i]);\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = "
                << fAttrActivationAlpha[direction * 2] << " * (1. - r) / (1. + r);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float za = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_update_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float zb = (za > 0.) ? za : 0.;\n";
            out << SP << SP << SP << OpName << "_update_gate[i] = (zb < 1.) ? zb : 1.;\n";
            out << SP << SP << SP << "float ra = " << fAttrActivationAlpha[direction * 2] << " * "
                << OpName << "_reset_gate[i] + " << fAttrActivationBeta[direction * 2] << ";\n";
            out << SP << SP << SP << "float rb = (ra > 0.) ? ra : 0.;\n";
            out << SP << SP << SP << OpName << "_reset_gate[i] = (rb < 1.) ? rb : 1.;\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_update_gate[i];\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * " << OpName << "_reset_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = 0.;\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < "
             << fAttrActivationAlpha[direction * 2] << ")\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_update_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_update_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_update_gate[i]) - 1.);\n";
         out << SP << SP << SP << "if (" << OpName << "_reset_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_reset_gate[i] = "
             << fAttrActivationAlpha[direction * 2] << " * (exp(" << OpName << "_reset_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = " << OpName
             << "_update_gate[i] / (1. + abs(" << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = " << OpName
             << "_reset_gate[i] / (1. + abs(" << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      } else {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_update_gate[i] = log(1. + exp("
             << OpName << "_update_gate[i]));\n";
         out << SP << SP << SP << OpName << "_reset_gate[i] = log(1. + exp("
             << OpName << "_reset_gate[i]));\n";
         out << SP << SP << "}\n";
      }

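      // Hidden gate: with linear_before_reset = 0 the reset gate multiplies
      // H_{t-1} before the product with R_h^T; otherwise it multiplies the
      // already transformed value H_{t-1}*R_h^T + Rb_h.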
      if (fAttrLinearBeforeReset == 0) {
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
            out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
                << "_reset_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
            out << SP << SP << SP << "}\n";
         }
         out << SP << SP << "} else {\n";
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_feedback[i] = " << OpName
             << "_reset_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
         out << SP << SP << SP << "}\n";
         out << SP << SP << "}\n";

         size_t rh_offset = (direction == 0)
             ? 2 * fAttrHiddenSize * fAttrHiddenSize
             : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_"
             << fNR << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_feedback, &" << OpName
             << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
      } else {
         size_t rh_offset = (direction == 0)
             ? 2 * fAttrHiddenSize * fAttrHiddenSize
             : 3 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
         out << SP << SP << "if (seq == 0) {\n";
         if (!fNInitial_h.empty()) {
            out << SP << SP << SP
                << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &" << OpName << "_n, &"
                << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
                << rh_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &"
                << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         }
         out << SP << SP << "} else {\n";
         if (direction == 0) {
            if (fAttrDirection == "backward") {
               out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            } else {
               out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                   << num_directions * batch_size * fAttrHiddenSize << ";\n";
            }
         } else {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
         }
         out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
             << OpName << "_n, &" << OpName << "_m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR
             << " + " << rh_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
             << OpName << "_n, &" << OpName << "_beta, " << OpName << "_feedback, &" << OpName << "_n);\n";
         out << SP << SP << "}\n";

         if (!fNB.empty()) {
            size_t rbh_offset = (direction == 0) ? 5 * batch_size * seq_length * fAttrHiddenSize
                                                 : 11 * batch_size * seq_length * fAttrHiddenSize;
            out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName
                << "_alpha, tensor_" << fNB << " + " << rbh_offset << ", &" << OpName << "_incx, "
                << OpName << "_feedback, &" << OpName << "_incy);\n";
         }

         out << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_feedback[i] *= " << OpName << "_reset_gate[i + offset];\n";
         out << SP << SP << "}\n";
      }

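      // Add the recurrent contribution to the hidden gate, then clip and
      // apply the second activation function g.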
      out << SP << SP << "BLAS::saxpy_(&" << OpName << "_feedback_size, &" << OpName << "_alpha, "
          << OpName << "_feedback, &" << OpName << "_incx, " << OpName << "_hidden_gate + offset, &"
          << OpName << "_incy);\n";

      if (fAttrClip > 0.) {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float x = (" << OpName << "_hidden_gate[i] > " << -fAttrClip
                << ") ? " << OpName << "_hidden_gate[i] : " << -fAttrClip << ";\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (x < " << fAttrClip << ") ? x : "
                << fAttrClip << ";\n";
         }
         out << SP << SP << "}\n";
      }

      if (fAttrActivations[direction * 2 + 1] == "Relu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Tanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_hidden_gate[i]);\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (1. - ex) / (1. + ex);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Sigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = 1. / (1. + exp(-" << OpName
             << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Affine") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i] + "
             << fAttrActivationBeta[direction * 2 + 1] << ";\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ScaledTanh") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 2 + 1]
                << " * " << OpName << "_hidden_gate[i]);\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = "
                << fAttrActivationAlpha[direction * 2 + 1] << " * (1. - ex) / (1. + ex);\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "HardSigmoid") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         if (fType == "float") {
            out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 2 + 1] << " * "
                << OpName << "_hidden_gate[i] + " << fAttrActivationBeta[direction * 2 + 1] << ";\n";
            out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
            out << SP << SP << SP << OpName << "_hidden_gate[i] = (b < 1.) ? b : 1.;\n";
         }
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "LeakyRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * " << OpName << "_hidden_gate[i];\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "ThresholdRelu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < "
             << fAttrActivationAlpha[direction * 2 + 1] << ")\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = 0.;\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Elu") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << "if (" << OpName << "_hidden_gate[i] < 0.)\n";
         out << SP << SP << SP << SP << OpName << "_hidden_gate[i] = "
             << fAttrActivationAlpha[direction * 2 + 1] << " * (exp(" << OpName << "_hidden_gate[i]) - 1.);\n";
         out << SP << SP << "}\n";
      } else if (fAttrActivations[direction * 2 + 1] == "Softsign") {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = " << OpName
             << "_hidden_gate[i] / (1. + abs(" << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      } else {
         out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
         out << SP << SP << SP << OpName << "_hidden_gate[i] = log(1. + exp("
             << OpName << "_hidden_gate[i]));\n";
         out << SP << SP << "}\n";
      }

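      // Blend the candidate state with the previous one:
      // H_t = (1 - z_t) . h_t + z_t . H_{t-1}.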
      out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
      out << SP << SP << SP << OpName << "_hidden_state[i] = (1. - " << OpName
          << "_update_gate[i]) * " << OpName << "_hidden_gate[i];\n";
      out << SP << SP << "}\n";

      out << SP << SP << "if (seq == 0) {\n";
      if (!fNInitial_h.empty()) {
         out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
         out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
             << "_update_gate[i + offset] * " << OpName << "_initial_hidden_state[i];\n";
         out << SP << SP << SP << "}\n";
      }
      out << SP << SP << "} else {\n";
      if (direction == 0) {
         if (fAttrDirection == "backward") {
            out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         } else {
            out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
                << num_directions * batch_size * fAttrHiddenSize << ";\n";
         }
      } else {
         out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
             << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
      }
      out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
      out << SP << SP << SP << SP << OpName << "_hidden_state[i + offset] += " << OpName
          << "_update_gate[i + offset] * " << OpName << "_hidden_state[i + previous_offset];\n";
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";

      out << SP << "}\n";
   }

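   // Zero the hidden state of the padded time steps, i.e. those at or beyond
   // the per-batch length given in sequence_lens.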
   if (!fNSequence_lens.empty()) {
      out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
      out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
      out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
      for (size_t direction = 0; direction < num_directions; direction++) {
         out << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
         out << SP << SP << SP << SP << SP << OpName << "_hidden_state[seq * "
             << num_directions * batch_size * fAttrHiddenSize << " + "
             << direction * batch_size * fAttrHiddenSize
             << " + batch * " << fAttrHiddenSize << " + h] = 0.;\n";
         out << SP << SP << SP << SP << "}\n";
      }
      out << SP << SP << SP << "}\n";
      out << SP << SP << "}\n";
      out << SP << "}\n";
   }

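   // Copy the hidden state into the requested outputs Y and Y_h, honouring
   // the layout and, for Y_h, the last valid time step of each batch entry.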
   if (fAttrLayout == 0) {
      if (!fNY_h.empty()) {
         if (fNSequence_lens.empty()) {
            size_t yh_size = batch_size * fAttrHiddenSize;
            if (fAttrDirection == "backward") {
               out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
                   << yh_size << ", tensor_" << fNY_h << ");\n";
            } else {
               size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
               out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
                   << "_hidden_state + " << offset << " + " << yh_size << ", tensor_" << fNY_h << ");\n";
            }
            if (num_directions == 2) {
               out << SP << "std::copy(" << OpName << "_hidden_state + " << yh_size << ", " << OpName
                   << "_hidden_state + " << 2 * yh_size << ", tensor_" << fNY_h << " + " << yh_size << ");\n";
            }
         } else {
            if (fAttrDirection == "backward") {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
               out << SP << "}\n";
            } else {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t seq = tensor_" << fNSequence_lens << "[batch] - 1;\n";
               out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
            if (num_directions == 2) {
               out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
               out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "size_t yh_offset = " << batch_size * fAttrHiddenSize
                   << " + batch * " << fAttrHiddenSize << ";\n";
               out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                   << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
               out << SP << "}\n";
            }
         }
      }
   } else {
      if (!fNY.empty()) {
         for (size_t direction = 0; direction < num_directions; direction++) {
            out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
            out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
                << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
            out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
            out << SP << SP << "}\n";
            out << SP << "}\n";
         }
      }
      if (!fNY_h.empty()) {
         if (fAttrDirection == "backward") {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         } else {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            if (fNSequence_lens.empty()) {
               out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
            } else {
               out << SP << SP << "size_t seq = tensor_" << fNSequence_lens << "[batch] - 1;\n";
            }
            out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
                << " + batch * " << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
         if (num_directions == 2) {
            out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
            out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "size_t yh_offset = batch * " << num_directions * fAttrHiddenSize << " + "
                << fAttrHiddenSize << ";\n";
            out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
                << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + yh_offset);\n";
            out << SP << "}\n";
         }
      }
   }

   return out.str();
}

} // namespace SOFIE
} // namespace Experimental
} // namespace TMVA

#endif