File indexing completed on 2025-01-18 10:11:07
0001 #ifndef TMVA_SOFIE_ROPERATOR_LSTM_I
0002 #define TMVA_SOFIE_ROPERATOR_LSTM_I
0003
0004 namespace TMVA {
0005 namespace Experimental {
0006 namespace SOFIE {
0007
0008 template<typename T>
0009 auto ROperator_LSTM<T>::TypeInference(std::vector<ETensorType> input)
0010 -> std::vector<ETensorType> {
0011 ETensorType out = input[0];
0012 return {out, out};
0013 }
0014
0015 template<typename T>
0016 auto ROperator_LSTM<T>::ShapeInference(std::vector<std::vector<size_t>> input)
0017 -> std::vector<std::vector<size_t>> {
0018 size_t num_directions = input[1][0];
0019 size_t hidden_size = input[1][1] / 4;
0020 if (fAttrLayout == 0) {
0021 size_t seq_length = input[0][0];
0022 size_t batch_size = input[0][1];
0023 std::vector<std::vector<size_t>> ret(
0024 {{seq_length, num_directions, batch_size, hidden_size},
0025 {num_directions, batch_size, hidden_size},
0026 {num_directions, batch_size, hidden_size}});
0027 return ret;
0028 } else {
0029 size_t batch_size = input[0][0];
0030 size_t seq_length = input[0][1];
0031 std::vector<std::vector<size_t>> ret(
0032 {{batch_size, seq_length, num_directions, hidden_size},
0033 {batch_size, num_directions, hidden_size},
0034 {batch_size, num_directions, hidden_size}});
0035 return ret;
0036 }
0037 }
0038
0039 template<typename T>
0040 auto ROperator_LSTM<T>::Initialize(RModel &model)
0041 -> void {
0042 fUseSession = model.UseSession();
0043
0044 if (!model.CheckIfTensorAlreadyExist(fNX)) {
0045 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNX + " is not found in model.");
0046 }
0047 fShapeX = model.GetTensorShape(fNX);
0048 if (fShapeX.size() != 3) {
0049 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNX + " is not of 3 dimensions.");
0050 }
0051 if (!model.CheckIfTensorAlreadyExist(fNW)) {
0052 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNW + " is not found in model.");
0053 }
0054 fShapeW = model.GetTensorShape(fNW);
0055 if (fShapeW.size() != 3) {
0056 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNW + " is not of 3 dimensions.");
0057 }
0058 if (!model.CheckIfTensorAlreadyExist(fNR)) {
0059 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNR + " is not found in model.");
0060 }
0061 fShapeR = model.GetTensorShape(fNR);
0062 if (fShapeR.size() != 3) {
0063 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " + fNR + " is not of 3 dimensions.");
0064 }
0065 if (!fNB.empty()) {
0066 if (!model.CheckIfTensorAlreadyExist(fNB)) {
0067 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNB + " is not found in model.");
0068 }
0069 fShapeB = model.GetTensorShape(fNB);
0070 if (fShapeB.size() != 2 && fShapeB.size() != 5) {
0071 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNB + " is not of 2 or 5 dimensions.");
0072 }
0073 if (fShapeB.size() == 2) {
0074
0075 auto original_data = model.GetInitializedTensorData(fNB);
0076 size_t num_directions = fShapeW[0];
0077 size_t seq_length = (fAttrLayout == 0)? fShapeX[0] : fShapeX[1];
0078 size_t batch_size = (fAttrLayout == 0)? fShapeX[1] : fShapeX[0];
0079 if (fType == "float") {
0080 float *original_bias = static_cast<float*>(original_data.get());
0081 float *new_bias = new float[4 * num_directions * seq_length * batch_size * fAttrHiddenSize];
0082 for (size_t gate = 0; gate < 4; gate++) {
0083 float sum[fAttrHiddenSize];
0084 for (size_t direction = 0; direction < num_directions; direction++) {
0085 size_t offset = direction * 8 * fAttrHiddenSize + gate * fAttrHiddenSize;
0086 for (size_t h = 0; h < fAttrHiddenSize; h++) {
0087 sum[h] = original_bias[offset + h] + original_bias[offset + h + 4 * fAttrHiddenSize];
0088 }
0089 for (size_t seq = 0; seq < seq_length; seq++) {
0090 for (size_t batch = 0; batch < batch_size; batch++) {
0091 size_t bias_offset = gate * num_directions * seq_length * batch_size * fAttrHiddenSize
0092 + direction * seq_length * batch_size * fAttrHiddenSize
0093 + seq * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
0094 std::copy(sum, sum + fAttrHiddenSize, new_bias + bias_offset);
0095 }
0096 }
0097 }
0098 }
0099 std::vector<size_t> new_bias_shape = {4, num_directions, seq_length, batch_size, fAttrHiddenSize};
0100 std::shared_ptr<void> new_bias_ptr(new_bias, std::default_delete<float[]>());
0101 model.UpdateInitializedTensor(fNB, model.GetTensorType(fNB), new_bias_shape, new_bias_ptr);
0102 fShapeB = model.GetTensorShape(fNB);
0103 }
0104 }
0105 }
0106 if (!fNSequence_lens.empty()) {
0107 if (!model.CheckIfTensorAlreadyExist(fNSequence_lens)) {
0108 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " +
0109 fNSequence_lens +
0110 "is not found in model.");
0111 }
0112 fShapeSequence_lens = model.GetTensorShape(fNSequence_lens);
0113 if (fShapeSequence_lens.size() != 1) {
0114 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " +
0115 fNSequence_lens +
0116 " is not of 1 dimension.");
0117 }
0118 }
0119 if (!fNInitial_h.empty()) {
0120 if (!model.CheckIfTensorAlreadyExist(fNInitial_h)) {
0121 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " +
0122 fNInitial_h + " is not found in model.");
0123 }
0124 fShapeInitial_h = model.GetTensorShape(fNInitial_h);
0125 if (fShapeInitial_h.size() != 3) {
0126 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " +
0127 fNInitial_h + " is not of 3 dimensions.");
0128 }
0129 }
0130 if (!fNInitial_c.empty()) {
0131 if (!model.CheckIfTensorAlreadyExist(fNInitial_c)) {
0132 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " +
0133 fNInitial_c + " is not found in model.");
0134 }
0135 fShapeInitial_c = model.GetTensorShape(fNInitial_c);
0136 if (fShapeInitial_c.size() != 3) {
0137 throw std::runtime_error("TMVA SOFIE LSTM Op input tensor " +
0138 fNInitial_c + " is not of 3 dimensions.");
0139 }
0140 }
0141 if (!fNP.empty()) {
0142 if (!model.CheckIfTensorAlreadyExist(fNP)) {
0143 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNP + " is not found in model.");
0144 }
0145 fShapeP = model.GetTensorShape(fNP);
0146 if (fShapeP.size() != 2 && fShapeP.size() != 4) {
0147 throw std::runtime_error("TMVA SOFIE LSTM op input tensor " + fNP + " is not of 2 or 4 dimensions.");
0148 }
0149 if (fShapeP.size() == 2) {
0150
0151 auto original_data = model.GetInitializedTensorData(fNP);
0152 size_t num_directions = fShapeW[0];
0153 size_t batch_size = (fAttrLayout == 0)? fShapeX[1] : fShapeX[0];
0154 if (fType == "float") {
0155 float *original_p = static_cast<float*>(original_data.get());
0156 float *new_p = new float[num_directions * 3 * batch_size * fAttrHiddenSize];
0157 for (size_t direction = 0; direction < num_directions; direction++) {
0158 for (size_t gate = 0; gate < 3; gate++) {
0159 size_t p_offset = direction * 3 * fAttrHiddenSize + gate * fAttrHiddenSize;
0160 for (size_t batch = 0; batch < batch_size; batch++) {
0161 size_t offset = direction * 3 * batch_size * fAttrHiddenSize
0162 + gate * batch_size * fAttrHiddenSize + batch * fAttrHiddenSize;
0163 std::copy(original_p + p_offset, original_p + p_offset + fAttrHiddenSize,
0164 new_p + offset);
0165 }
0166 }
0167 }
0168 std::vector<size_t> new_p_shape = {num_directions, 3, batch_size, fAttrHiddenSize};
0169 std::shared_ptr<void> new_p_ptr(new_p, std::default_delete<float[]>());
0170 model.UpdateInitializedTensor(fNP, model.GetTensorType(fNP), new_p_shape, new_p_ptr);
0171 fShapeP = model.GetTensorShape(fNP);
0172 }
0173 }
0174 }
0175 if (!fNY.empty()) {
0176 fShapeY = ShapeInference({fShapeX, fShapeW})[0];
0177 if (!model.CheckIfTensorAlreadyExist(fNY)) {
0178 model.AddIntermediateTensor(fNY, model.GetTensorType(fNX), fShapeY);
0179 }
0180 }
0181 if (!fNY_h.empty()) {
0182 fShapeY_h = ShapeInference({fShapeX, fShapeW})[1];
0183 if (!model.CheckIfTensorAlreadyExist(fNY_h)) {
0184 model.AddIntermediateTensor(fNY_h, model.GetTensorType(fNX), fShapeY_h);
0185 }
0186 }
0187 if (!fNY_c.empty()) {
0188 fShapeY_c = ShapeInference({fShapeX, fShapeW})[2];
0189 if (!model.CheckIfTensorAlreadyExist(fNY_c)) {
0190 model.AddIntermediateTensor(fNY_c, model.GetTensorType(fNX), fShapeY_c);
0191 }
0192 }
0193
0194 for (auto &activation : fAttrActivations) {
0195 if (activation != "Relu" && activation != "Tanh" &&
0196 activation != "Sigmoid" && activation != "Affine" &&
0197 activation != "LeakyRelu" && activation != "ThresholdRelu" &&
0198 activation != "ScaledTanh" && activation != "HardSigmoid" &&
0199 activation != "Elu" && activation != "Softsign" &&
0200 activation != "Softplus") {
0201 throw std::runtime_error("TMVA SOFIE - Activation function " +
0202 activation + " not implemented");
0203 }
0204 }
0205 if (fAttrDirection != "forward" && fAttrDirection != "backward" &&
0206 fAttrDirection != "bidirectional") {
0207 throw std::runtime_error(
0208 "TMVA SOFIE - Invalid LSTM direction fAttrDirection = " +
0209 fAttrDirection);
0210 }
0211 if (4 * fAttrHiddenSize != fShapeW[1]) {
0212 throw std::runtime_error(
0213 "TMVA SOFIE - fAttrHiddenSize must be equal to " +
0214 std::to_string(fShapeW[1] / 4));
0215 }
0216 if (fAttrInputForget > 1) {
0217 throw std::runtime_error(
0218 "TMVA SOFIE - fAttrInputForget = " + std::to_string(fAttrInputForget)
0219 + " must be 0 or 1.");
0220 }
0221 if (fAttrLayout > 1) {
0222 throw std::runtime_error("TMVA SOFIE - Layout fAttrLayout = " +
0223 std::to_string(fAttrLayout) +
0224 " must be 0 (timewise) or 1 (batchwise)");
0225 }
0226 if (fAttrActivations.empty()) {
0227 if (fAttrDirection == "bidirectional") {
0228 fAttrActivations = {"Sigmoid", "Tanh", "Tanh", "Sigmoid", "Tanh", "Tanh"};
0229 } else {
0230 fAttrActivations = {"Sigmoid", "Tanh", "Tanh"};
0231 }
0232 }
0233 }
0234
0235
0236 template <typename T>
0237 std::string ROperator_LSTM<T>::GenerateSessionMembersCode(std::string opName)
0238 {
0239 opName = "op_" + opName;
0240 std::stringstream out;
0241
0242 size_t num_directions = fShapeW[0];
0243 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
0244 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
0245 size_t input_size = fShapeX[2];
0246
0247 if (fAttrLayout != 0) {
0248 out << "std::vector<" << fType << "> fVec_" << opName << "_input = std::vector<" << fType << ">("
0249 << seq_length * batch_size * input_size << ");\n";
0250 out << "std::vector<" << fType << "> fVec_" << opName << "_initial_hidden_state = std::vector<" << fType << ">("
0251 << num_directions * batch_size * fAttrHiddenSize << ");\n";
0252 out << "std::vector<" << fType << "> fVec_" << opName << "_initial_cell_state = std::vector<" << fType << ">("
0253 << num_directions * batch_size * fAttrHiddenSize << ");\n";
0254 }
0255
0256 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
0257 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_input_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
0258 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_output_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
0259 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_cell_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
0260 if (fAttrInputForget == 0)
0261 out << "std::vector<" << fType << "> fVec_" << opName << "_ff_forget_gate = std::vector<" << fType << ">(" << ff_size << ");\n";
0262
0263 size_t hs_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
0264 out << "std::vector<" << fType << "> fVec_" << opName << "_input_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
0265 out << "std::vector<" << fType << "> fVec_" << opName << "_output_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
0266 out << "std::vector<" << fType << "> fVec_" << opName << "_cell_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
0267 if (fAttrInputForget == 0)
0268 out << "std::vector<" << fType << "> fVec_" << opName << "_forget_gate = std::vector<" << fType << ">(" << hs_size << ");\n";
0269
0270 out << "std::vector<" << fType << "> fVec_" << opName << "_cell_state = std::vector<" << fType << ">(" << hs_size << ");\n";
0271 out << "std::vector<" << fType << "> fVec_" << opName << "_new_cell_state = std::vector<" << fType << ">(" << hs_size << ");\n";
0272
0273 if (fAttrLayout != 0 || fNY.empty()) {
0274 out << "std::vector<" << fType << "> fVec_" << opName << "_hidden_state = std::vector<" << fType << ">(" << hs_size << ");\n";
0275 }
0276
0277 out << "\n";
0278
0279 return out.str();
0280 }
0281
0282 template<typename T>
0283 auto ROperator_LSTM<T>::Generate(std::string OpName)
0284 -> std::string {
0285 OpName = "op_" + OpName;
0286 std::stringstream out;
0287
0288 size_t seq_length = (fAttrLayout == 0) ? fShapeX[0] : fShapeX[1];
0289 size_t batch_size = (fAttrLayout == 0) ? fShapeX[1] : fShapeX[0];
0290 size_t input_size = fShapeX[2];
0291 size_t num_directions = fShapeW[0];
0292
0293
0294 if (fAttrLayout == 0) {
0295 out << SP << fType << " *" << OpName << "_input = tensor_" << fNX << ";\n";
0296 } else {
0297 if (fUseSession)
0298 out << SP << fType << " * " << OpName << "_input = fVec_" << OpName << "_input.data();\n";
0299 else
0300 out << SP << fType << " " << OpName << "_input[" << seq_length * batch_size * input_size << "] = {0};\n";
0301
0302 out << SP << "for(size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
0303 out << SP << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
0304 out << SP << SP << SP << "for(size_t i = 0; i < " << input_size << "; i++) {\n";
0305 out << SP << SP << SP << SP << OpName << "_input[seq * " << batch_size * input_size
0306 << " + batch * " << input_size << " + i] = " << "tensor_" << fNX << "[batch * "
0307 << seq_length * input_size << " + seq * " << input_size << " + i];\n";
0308 out << SP << SP << SP << "}\n";
0309 out << SP << SP << "}\n";
0310 out << SP << "}\n";
0311 }
0312
0313
0314 if (!fNInitial_h.empty()) {
0315 if (fAttrLayout == 0) {
0316 out << SP << fType << " *" << OpName << "_initial_hidden_state = " << " tensor_"
0317 << fNInitial_h << ";\n";
0318 } else {
0319 if (fUseSession)
0320 out << SP << fType << " * " << OpName << "_initial_hidden_state = fVec_" << OpName
0321 << "_initial_hidden_state.data();\n";
0322 else
0323 out << SP << fType << " " << OpName << "_initial_hidden_state[" << num_directions * batch_size *
0324 fAttrHiddenSize << "] = {0};\n";
0325
0326 for (size_t direction = 0; direction < num_directions; direction++) {
0327 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
0328 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
0329 out << SP << SP << SP << OpName << "_initial_hidden_state["
0330 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
0331 << " + h] = tensor_" << fNInitial_h << "[batch * " << num_directions * fAttrHiddenSize
0332 << " + " << direction * fAttrHiddenSize << " + h];\n";
0333 out << SP << SP << "}\n";
0334 out << SP << "}\n";
0335 }
0336 }
0337 }
0338
0339
0340 if (!fNInitial_c.empty()) {
0341 if (fAttrLayout == 0) {
0342 out << SP << fType << " *" << OpName << "_initial_cell_state = " << " tensor_"
0343 << fNInitial_c << ";\n";
0344 } else {
0345 if (fUseSession)
0346 out << SP << fType << " * " << OpName << "_initial_cell_state = fVec_" << OpName
0347 << "_initial_cell_state.data();\n";
0348 else
0349 out << SP << fType << " " << OpName << "_initial_cell_state[" << num_directions * batch_size *
0350 fAttrHiddenSize << "] = {0};\n";
0351
0352 for (size_t direction = 0; direction < num_directions; direction++) {
0353 out << SP << "for(size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
0354 out << SP << SP << "for(size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
0355 out << SP << SP << SP << OpName << "_initial_cell_state["
0356 << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize
0357 << " + h] = tensor_" << fNInitial_c << "[batch * " << num_directions * fAttrHiddenSize
0358 << " + " << direction * fAttrHiddenSize << " + h];\n";
0359 out << SP << SP << "}\n";
0360 out << SP << "}\n";
0361 }
0362 }
0363 }
0364
0365
0366 size_t ff_size = seq_length * batch_size * fAttrHiddenSize;
0367 if (fUseSession) {
0368 out << SP << fType << " * " << OpName << "_ff_input_gate = fVec_" << OpName << "_ff_input_gate.data();\n";
0369 out << SP << fType << " * " << OpName << "_ff_output_gate = fVec_" << OpName << "_ff_output_gate.data();\n";
0370 out << SP << fType << " * " << OpName << "_ff_cell_gate = fVec_" << OpName << "_ff_cell_gate.data();\n";
0371 if (fAttrInputForget == 0) {
0372 out << SP << fType << " * " << OpName << "_ff_forget_gate = fVec_" << OpName << "_ff_forget_gate.data();\n";
0373 }
0374 } else {
0375 out << SP << fType << " " << OpName << "_ff_input_gate[" << ff_size << "] = {0};\n";
0376 out << SP << fType << " " << OpName << "_ff_output_gate[" << ff_size << "] = {0};\n";
0377 out << SP << fType << " " << OpName << "_ff_cell_gate[" << ff_size << "] = {0};\n";
0378 if (fAttrInputForget == 0) {
0379 out << SP << fType << " " << OpName << "_ff_forget_gate[" << ff_size << "] = {0};\n";
0380 }
0381 }
0382
0383 size_t hidden_state_size = seq_length * num_directions * batch_size * fAttrHiddenSize;
0384 if (fUseSession) {
0385 out << SP << fType << " * " << OpName << "_input_gate = fVec_" << OpName << "_input_gate.data();\n";
0386 out << SP << fType << " * " << OpName << "_output_gate = fVec_" << OpName << "_output_gate.data();\n";
0387 out << SP << fType << " * " << OpName << "_cell_gate = fVec_" << OpName << "_cell_gate.data();\n";
0388 if (fAttrInputForget == 0) {
0389 out << SP << fType << " * " << OpName << "_forget_gate = fVec_" << OpName << "_forget_gate.data();\n";
0390 }
0391 } else {
0392 out << SP << fType << " " << OpName << "_input_gate[" << hidden_state_size << "] = {0};\n";
0393 out << SP << fType << " " << OpName << "_output_gate[" << hidden_state_size << "] = {0};\n";
0394 out << SP << fType << " " << OpName << "_cell_gate[" << hidden_state_size << "] = {0};\n";
0395 if (fAttrInputForget == 0) {
0396 out << SP << fType << " " << OpName << "_forget_gate[" << hidden_state_size << "] = {0};\n";
0397 }
0398 }
0399
0400 if (fUseSession) {
0401 out << SP << fType << " * " << OpName << "_cell_state = fVec_" << OpName << "_cell_state.data();\n";
0402 out << SP << fType << " * " << OpName << "_new_cell_state = fVec_" << OpName << "_new_cell_state.data();\n";
0403 } else {
0404 out << SP << fType << " " << OpName << "_cell_state[" << hidden_state_size << "] = {0};\n";
0405 out << SP << fType << " " << OpName << "_new_cell_state[" << hidden_state_size << "] = {0};\n";
0406 }
0407
0408
0409 if (fAttrLayout == 0 && !fNY.empty()) {
0410 out << SP << fType << " *" << OpName << "_hidden_state = tensor_" << fNY << ";\n";
0411 } else {
0412 if (fUseSession) {
0413 out << SP << fType << " * " << OpName << "_hidden_state = fVec_" << OpName << "_hidden_state.data();\n";
0414 } else {
0415 out << SP << fType << " " << OpName << "_hidden_state[" << hidden_state_size << "] = {0};\n";
0416 }
0417 }
0418
0419 out << SP << "char " << OpName << "_transA = 'N';\n";
0420 out << SP << "char " << OpName << "_transB = 'T';\n";
0421 out << SP << "int " << OpName << "_m = " << seq_length * batch_size << ";\n";
0422 out << SP << "int " << OpName << "_n = " << fAttrHiddenSize << ";\n";
0423 out << SP << "int " << OpName << "_k = " << input_size << ";\n";
0424 if (fType == "float") {
0425 out << SP << fType << " " << OpName << "_alpha = 1.;\n";
0426 out << SP << fType << " " << OpName << "_beta = 0.;\n";
0427 }
0428 if (!fNB.empty()) {
0429 out << SP << "int " << OpName << "_bias_size = " << seq_length * batch_size * fAttrHiddenSize << ";\n";
0430 out << SP << "int " << OpName << "_incx = 1;\n";
0431 out << SP << "int " << OpName << "_incy = 1;\n";
0432 }
0433
0434 for (size_t direction = 0; direction < num_directions; direction++) {
0435 if (direction == 0) {
0436 if (fType == "float") {
0437
0438 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0439 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0440 << fNW << ", &" << OpName << "_k, " << OpName << "_input, &" << OpName << "_k, &"
0441 << OpName << "_beta, " << OpName << "_ff_input_gate, &" << OpName << "_n);\n";
0442
0443 size_t wo_offset = fAttrHiddenSize * input_size;
0444 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0445 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0446 << fNW << " + " << wo_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
0447 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_ff_output_gate, &" << OpName << "_n);\n";
0448
0449 size_t wc_offset = 3 * fAttrHiddenSize * input_size;
0450 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0451 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0452 << fNW << " + " << wc_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
0453 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_ff_cell_gate, &" << OpName << "_n);\n";
0454 }
0455 } else {
0456 if (fType == "float") {
0457
0458 size_t wi_offset = 4 * fAttrHiddenSize * input_size;
0459 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0460 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0461 << fNW << " + " << wi_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
0462 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_ff_input_gate, &" << OpName << "_n);\n";
0463
0464 size_t wo_offset = 4 * fAttrHiddenSize * input_size + 1 * fAttrHiddenSize * input_size;
0465 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0466 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0467 << fNW << " + " << wo_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
0468 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_ff_output_gate, &" << OpName << "_n);\n";
0469
0470 size_t wc_offset = 4 * fAttrHiddenSize * input_size + 3 * fAttrHiddenSize * input_size;
0471 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0472 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0473 << fNW << " + " << wc_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
0474 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_ff_cell_gate, &" << OpName << "_n);\n";
0475 }
0476 }
0477 if (fAttrInputForget == 0) {
0478
0479 if (direction == 0) {
0480 if (fType == "float") {
0481 size_t wf_offset = 2 * fAttrHiddenSize * input_size;
0482 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0483 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0484 << fNW << " + " << wf_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
0485 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_ff_forget_gate, &" << OpName << "_n);\n";
0486 }
0487 } else {
0488 if (fType == "float") {
0489 size_t wf_offset = 4 * fAttrHiddenSize * input_size + 2 * fAttrHiddenSize * input_size;
0490 out << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0491 << OpName <<"_n, &" << OpName << "_m, &" << OpName << "_k, &" << OpName << "_alpha, tensor_"
0492 << fNW << " + " << wf_offset << ", &" << OpName << "_k, " << OpName << "_input, &"
0493 << OpName << "_k, &" << OpName << "_beta, " << OpName << "_ff_forget_gate, &" << OpName << "_n);\n";
0494 }
0495 }
0496 }
0497
0498
0499 if (!fNB.empty()) {
0500 if (direction == 0) {
0501 if (fType == "float") {
0502
0503 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0504 << fNB << ", &" << OpName << "_incx, " << OpName << "_ff_input_gate, &" << OpName << "_incy);\n";
0505
0506 size_t bo_offset = seq_length * batch_size * fAttrHiddenSize;
0507 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0508 << fNB << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_output_gate, &"
0509 << OpName << "_incy);\n";
0510
0511 size_t bc_offset = 3 * seq_length * batch_size * fAttrHiddenSize;
0512 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0513 << fNB << " + " << bc_offset << ", &" << OpName << "_incx, " << OpName << "_ff_cell_gate, &"
0514 << OpName << "_incy);\n";
0515 }
0516 } else {
0517 if (fType == "float") {
0518
0519 size_t bi_offset = 4 * seq_length * batch_size * fAttrHiddenSize;
0520 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0521 << fNB << " + " << bi_offset << ", &" << OpName << "_incx, " << OpName << "_ff_input_gate, &"
0522 << OpName << "_incy);\n";
0523
0524 size_t bo_offset = 4 * seq_length * batch_size * fAttrHiddenSize
0525 + seq_length * batch_size * fAttrHiddenSize;
0526 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0527 << fNB << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_output_gate, &"
0528 << OpName << "_incy);\n";
0529
0530 size_t bc_offset = 4 * num_directions * seq_length * batch_size * fAttrHiddenSize
0531 + 3 * seq_length * batch_size * fAttrHiddenSize;
0532 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0533 << fNB << " + " << bc_offset << ", &" << OpName << "_incx, " << OpName << "_ff_cell_gate, &"
0534 << OpName << "_incy);\n";
0535 }
0536 }
0537 if (fAttrInputForget == 0) {
0538
0539 if (direction == 0) {
0540 if (fType == "float") {
0541 size_t bo_offset = 2 * seq_length * batch_size * fAttrHiddenSize;
0542 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0543 << fNB << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_forget_gate, &"
0544 << OpName << "_incy);\n";
0545 }
0546 } else {
0547 if (fType == "float") {
0548 size_t bo_offset = 4 * seq_length * batch_size * fAttrHiddenSize
0549 + 2 * seq_length * batch_size * fAttrHiddenSize;
0550 out << SP << "BLAS::saxpy_(&" << OpName << "_bias_size, &" << OpName << "_alpha, tensor_"
0551 << fNB << " + " << bo_offset << ", &" << OpName << "_incx, " << OpName << "_ff_forget_gate, &"
0552 << OpName << "_incy);\n";
0553 }
0554 }
0555 }
0556 }
0557
0558
0559
0560 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
0561 out << SP << SP << "size_t ff_offset = seq * " << batch_size * fAttrHiddenSize << ";\n";
0562 if (direction == 0) {
0563 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
0564 << ";\n";
0565 } else {
0566 out << SP << SP << "size_t gate_offset = seq * " << num_directions * batch_size * fAttrHiddenSize
0567 << " + " << batch_size * fAttrHiddenSize << ";\n";
0568 }
0569 size_t ff_seq_size = batch_size * fAttrHiddenSize;
0570 out << SP << SP << "std::copy(" << OpName << "_ff_input_gate + ff_offset, " << OpName
0571 << "_ff_input_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_input_gate + gate_offset);\n";
0572 out << SP << SP << "std::copy(" << OpName << "_ff_output_gate + ff_offset, " << OpName
0573 << "_ff_output_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_output_gate + gate_offset);\n";
0574 out << SP << SP << "std::copy(" << OpName << "_ff_cell_gate + ff_offset, " << OpName
0575 << "_ff_cell_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_cell_gate + gate_offset);\n";
0576 if (fAttrInputForget == 0) {
0577 out << SP << SP << "std::copy(" << OpName << "_ff_forget_gate + ff_offset, " << OpName
0578 << "_ff_forget_gate + ff_offset + " << ff_seq_size << ", " << OpName << "_forget_gate + gate_offset);\n";
0579 }
0580 out << SP << "}\n";
0581
0582 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
0583 if (fAttrDirection == "backward" || direction == 1) {
0584 out << SP << SP << "size_t index = " << seq_length - 1 << " - seq;\n";
0585 } else {
0586 out << SP << SP << "size_t index = seq;\n";
0587 }
0588 out << SP << SP << "int m2 = " << batch_size << ";\n";
0589 if (direction == 0) {
0590 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
0591 << ";\n";
0592 } else {
0593 out << SP << SP << "size_t offset = index * " << num_directions * batch_size * fAttrHiddenSize
0594 << " + " << batch_size * fAttrHiddenSize << ";\n";
0595 }
0596 size_t size = batch_size * fAttrHiddenSize;
0597
0598 out << SP << SP << "if (seq == 0) {\n";
0599 if (!fNInitial_h.empty()) {
0600 if (direction == 0) {
0601 if (fType == "float") {
0602 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0603 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
0604 << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName << "_n, &" << OpName
0605 << "_alpha, " << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
0606 size_t ro_offset = fAttrHiddenSize * fAttrHiddenSize;
0607 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0608 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0609 << ro_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
0610 << "_n, &" << OpName << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
0611 size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
0612 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0613 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0614 << rc_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
0615 << "_n, &" << OpName << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
0616 if (fAttrInputForget == 0) {
0617 size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
0618 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0619 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0620 << rf_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
0621 << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
0622 }
0623 }
0624 } else {
0625 if (fType == "float") {
0626 size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
0627 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0628 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0629 << ri_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
0630 << "_n, &" << OpName << "_alpha, " << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
0631 size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 1 * fAttrHiddenSize * fAttrHiddenSize;
0632 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0633 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0634 << ro_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
0635 << "_n, &" << OpName << "_alpha, " << OpName << "_output_gate + offset, &" << OpName << "_n);\n";
0636 size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
0637 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0638 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0639 << rc_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
0640 << "_n, &" << OpName << "_alpha, " << OpName << "_cell_gate + offset, &" << OpName << "_n);\n";
0641 if (fAttrInputForget == 0) {
0642 size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
0643 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0644 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0645 << rf_offset << ", &" << OpName << "_n, " << OpName << "_initial_hidden_state, &" << OpName
0646 << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &" << OpName << "_n);\n";
0647 }
0648 }
0649 }
0650 }
0651 out << SP << SP << "} else {\n";
0652
0653 if (direction == 0) {
0654 if (fAttrDirection == "backward") {
0655 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
0656 << num_directions * batch_size * fAttrHiddenSize << ";\n";
0657 } else {
0658 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
0659 << num_directions * batch_size * fAttrHiddenSize << ";\n";
0660 }
0661 if (fType == "float") {
0662 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0663 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << ", &"
0664 << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &" << OpName << "_n, &"
0665 << OpName << "_alpha, " << OpName << "_input_gate + offset, &" << OpName << "_n);\n";
0666 size_t ro_offset = 1 * fAttrHiddenSize * fAttrHiddenSize;
0667 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0668 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0669 << ro_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
0670 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_output_gate + offset, &"
0671 << OpName << "_n);\n";
0672 size_t rc_offset = 3 * fAttrHiddenSize * fAttrHiddenSize;
0673 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0674 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0675 << rc_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
0676 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_cell_gate + offset, &"
0677 << OpName << "_n);\n";
0678 if (fAttrInputForget == 0) {
0679 size_t rf_offset = 2 * fAttrHiddenSize * fAttrHiddenSize;
0680 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0681 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0682 << rf_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
0683 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &"
0684 << OpName << "_n);\n";
0685 }
0686 }
0687 } else {
0688 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
0689 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
0690 if (fType == "float") {
0691 size_t ri_offset = 4 * fAttrHiddenSize * fAttrHiddenSize;
0692 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0693 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0694 << ri_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
0695 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_input_gate + offset, &"
0696 << OpName << "_n);\n";
0697 size_t ro_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + fAttrHiddenSize * fAttrHiddenSize;
0698 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0699 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0700 << ro_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
0701 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_output_gate + offset, &"
0702 << OpName << "_n);\n";
0703 size_t rc_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 3 * fAttrHiddenSize * fAttrHiddenSize;
0704 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0705 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0706 << rc_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
0707 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_cell_gate + offset, &"
0708 << OpName << "_n);\n";
0709 if (fAttrInputForget == 0) {
0710 size_t rf_offset = 4 * fAttrHiddenSize * fAttrHiddenSize + 2 * fAttrHiddenSize * fAttrHiddenSize;
0711 out << SP << SP << SP << "BLAS::sgemm_(&" << OpName << "_transB, &" << OpName << "_transA, &"
0712 << OpName << "_n, &m2, &" << OpName << "_n, &" << OpName << "_alpha, tensor_" << fNR << " + "
0713 << rf_offset << ", &" << OpName << "_n, " << OpName << "_hidden_state + previous_offset, &"
0714 << OpName << "_n, &" << OpName << "_alpha, " << OpName << "_forget_gate + offset, &"
0715 << OpName << "_n);\n";
0716 }
0717 }
0718 }
0719 out << SP << SP << "}\n";
0720
0721
0722 if (fAttrClip > .0) {
0723 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0724 if (fType == "float") {
0725 out << SP << SP << SP << "float x = (" << OpName << "_cell_gate[i] > " << -fAttrClip << ") ? "
0726 << OpName << "_cell_gate[i] : " << -fAttrClip << ";\n";
0727 }
0728 out << SP << SP << SP << OpName << "_cell_gate[i] = (x < " << fAttrClip << ") ? x : "
0729 << fAttrClip << ";\n";
0730 out << SP << SP << "}\n";
0731 }
0732
0733 if (fAttrActivations[direction * 3 + 1] == "Relu") {
0734 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0735 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
0736 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 0.;\n";
0737 out << SP << SP << "}\n";
0738 } else if (fAttrActivations[direction * 3 + 1] == "Tanh") {
0739 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0740 if (fType == "float") {
0741 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_cell_gate[i]);\n";
0742 }
0743 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = (1. - ex) / (1. + ex);\n";
0744 out << SP << SP << "}\n";
0745 } else if (fAttrActivations[direction * 3 + 1] == "Sigmoid") {
0746 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0747 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 1. / (1. + exp(-" << OpName
0748 << "_cell_gate[i]));\n";
0749 out << SP << SP << "}\n";
0750 } else if (fAttrActivations[direction * 3 + 1] == "Affine") {
0751 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0752 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = "
0753 << fAttrActivationAlpha[direction * 3 + 1] << " * " << OpName << "_cell_gate[i] + "
0754 << fAttrActivationBeta[direction * 3 + 1] << ";\n";
0755 out << SP << SP << "}\n";
0756 } else if (fAttrActivations[direction * 3 + 1] == "ScaledTanh") {
0757 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0758 if (fType == "float") {
0759 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3 + 1]
0760 << " * "<< OpName << "_cell_gate[i]);\n";
0761 }
0762 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = "
0763 << fAttrActivationAlpha[direction * 3 + 1] << " * (1. - ex) / (1. + ex);\n";
0764 out << SP << SP << "}\n";
0765 } else if (fAttrActivations[direction * 3 + 1] == "HardSigmoid") {
0766 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0767 if (fType == "float") {
0768 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3 + 1] << " * "
0769 << OpName << "_cell_gate[i] + " << fAttrActivationBeta[direction * 3 + 1] << ";\n";
0770 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
0771 }
0772 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = (b < 1.) ? b : 1.;\n";
0773 out << SP << SP << "}\n";
0774 } else if (fAttrActivations[direction * 3 + 1] == "LeakyRelu") {
0775 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0776 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
0777 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = "
0778 << fAttrActivationAlpha[direction * 3 + 1] << " * " << OpName << "_cell_gate[i];\n";
0779 out << SP << SP << "}\n";
0780 } else if (fAttrActivations[direction * 3 + 1] == "ThresholdRelu") {
0781 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0782 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < "
0783 << fAttrActivationAlpha[direction * 3 + 1] << ")\n";
0784 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = 0.;\n";
0785 out << SP << SP << "}";
0786 } else if (fAttrActivations[direction * 3 + 1] == "Elu") {
0787 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0788 out << SP << SP << SP << "if (" << OpName << "_cell_gate[i] < 0.)\n";
0789 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = "
0790 << fAttrActivationAlpha[direction * 3 + 1] << " * exp(" << OpName << "_cell_gate[i] - 1.);\n";
0791 out << SP << SP << "}\n";
0792 } else if (fAttrActivations[direction * 3 + 1] == "Softsign") {
0793 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0794 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = " << OpName
0795 << "_cell_gate[i] / (1. + abs(" << OpName << "_cell_gate[i]));\n";
0796 out << SP << SP << "}\n";
0797 } else {
0798 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0799 out << SP << SP << SP << SP << OpName << "_cell_gate[i] = log(1. + exp("
0800 << OpName << "_cell_gate[i]));\n";
0801 out << SP << SP << "}\n";
0802 }
0803
0804
0805 if (!fNP.empty()) {
0806
0807 out << SP << SP << "if (seq == 0) {\n";
0808 if (!fNInitial_c.empty()) {
0809 if (direction == 0) {
0810 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0811 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP
0812 << "[i] * " << OpName << "_initial_cell_state[i];\n";
0813 out << SP << SP << SP << "}\n";
0814 if (fAttrInputForget == 0) {
0815 size_t pf_offset = batch_size * fAttrHiddenSize;
0816 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0817 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP
0818 << "[i + " << pf_offset << "] * " << OpName << "_initial_cell_state[i];\n";
0819 out << SP << SP << SP << "}\n";
0820 }
0821 } else {
0822 size_t pi_offset = 3 * batch_size * fAttrHiddenSize;
0823 size_t initial_c_offset = batch_size * fAttrHiddenSize;
0824 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0825 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP
0826 << "[i + " << pi_offset << "] * " << OpName << "_initial_cell_state[i + " << initial_c_offset
0827 << "];\n";
0828 out << SP << SP << SP << "}\n";
0829 if (fAttrInputForget == 0) {
0830 size_t pf_offset = 3 * batch_size * fAttrHiddenSize + batch_size * fAttrHiddenSize;
0831 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0832 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP
0833 << "[i + " << pf_offset << "] * " << OpName << "_initial_cell_state[i + " << initial_c_offset
0834 << "];\n";
0835 out << SP << SP << SP << "}\n";
0836 }
0837 }
0838 }
0839 out << SP << SP << "} else {\n";
0840 if (direction == 0) {
0841 if (fAttrDirection == "backward") {
0842 out << SP << SP << SP << "size_t c_offset = (index + 1) * "
0843 << num_directions * batch_size * fAttrHiddenSize << ";\n";
0844 } else {
0845 out << SP << SP << SP << "size_t c_offset = (seq - 1) * "
0846 << num_directions * batch_size * fAttrHiddenSize << ";\n";
0847 }
0848 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0849 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP
0850 << "[i] * " << OpName << "_cell_state[i + c_offset];\n";
0851 out << SP << SP << SP << "}\n";
0852 if (fAttrInputForget == 0) {
0853 size_t pf_offset = batch_size * fAttrHiddenSize;
0854 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0855 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP
0856 << "[i + " << pf_offset << "] * " << OpName << "_cell_state[i + c_offset];\n";
0857 out << SP << SP << SP << "}\n";
0858 }
0859 } else {
0860 size_t pi_offset = 3 * batch_size * fAttrHiddenSize;
0861 out << SP << SP << SP << "size_t c_offset = (index + 1) * "
0862 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
0863 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0864 out << SP << SP << SP << SP << OpName << "_input_gate[i + offset] += tensor_" << fNP
0865 << "[i + " << pi_offset << "] * " << OpName << "_cell_state[i + c_offset];\n";
0866 out << SP << SP << SP << "}\n";
0867 if (fAttrInputForget == 0) {
0868 size_t pf_offset = 3 * batch_size * fAttrHiddenSize + batch_size * fAttrHiddenSize;
0869 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
0870 out << SP << SP << SP << SP << OpName << "_forget_gate[i + offset] += tensor_" << fNP
0871 << "[i + " << pf_offset << "] * " << OpName << "_cell_state[i + c_offset];\n";
0872 out << SP << SP << SP << "}\n";
0873 }
0874 }
0875 out << SP << SP << "}\n";
0876 }
0877
0878
0879 if (fAttrClip > .0) {
0880 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0881 if (fType == "float") {
0882 out << SP << SP << SP << "float x = (" << OpName << "_input_gate[i] > " << -fAttrClip << ") ? "
0883 << OpName << "_input_gate[i] : " << -fAttrClip << ";\n";
0884 }
0885 out << SP << SP << SP << OpName << "_input_gate[i] = (x < " << fAttrClip << ") ? x : "
0886 << fAttrClip << ";\n";
0887 out << SP << SP << "}\n";
0888 }
0889
0890 if (fAttrActivations[direction * 3] == "Relu") {
0891 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0892 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
0893 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 0.;\n";
0894 out << SP << SP << "}\n";
0895 } else if (fAttrActivations[direction * 3] == "Tanh") {
0896 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0897 if (fType == "float") {
0898 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_input_gate[i]);\n";
0899 }
0900 out << SP << SP << SP << SP << OpName << "_input_gate[i] = (1. - ex) / (1. + ex);\n";
0901 out << SP << SP << "}\n";
0902 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
0903 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0904 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 1. / (1. + exp(-" << OpName
0905 << "_input_gate[i]));\n";
0906 out << SP << SP << "}\n";
0907 } else if (fAttrActivations[direction * 3] == "Affine") {
0908 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0909 out << SP << SP << SP << SP << OpName << "_input_gate[i] = "
0910 << fAttrActivationAlpha[direction * 3] << " * " << OpName << "_input_gate[i] + "
0911 << fAttrActivationBeta[direction * 3] << ";\n";
0912 out << SP << SP << "}\n";
0913 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
0914 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0915 if (fType == "float") {
0916 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3]
0917 << " * "<< OpName << "_input_gate[i]);\n";
0918 }
0919 out << SP << SP << SP << SP << OpName << "_input_gate[i] = "
0920 << fAttrActivationAlpha[direction * 3] << " * (1. - ex) / (1. + ex);\n";
0921 out << SP << SP << "}\n";
0922 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
0923 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0924 if (fType == "float") {
0925 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * "
0926 << OpName << "_input_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
0927 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
0928 }
0929 out << SP << SP << SP << SP << OpName << "_input_gate[i] = (b < 1.) ? b : 1.;\n";
0930 out << SP << SP << "}\n";
0931 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
0932 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0933 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
0934 out << SP << SP << SP << SP << OpName << "_input_gate[i] = "
0935 << fAttrActivationAlpha[direction * 3] << " * " << OpName << "_input_gate[i];\n";
0936 out << SP << SP << "}\n";
0937 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
0938 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0939 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < "
0940 << fAttrActivationAlpha[direction * 3] << ")\n";
0941 out << SP << SP << SP << SP << OpName << "_input_gate[i] = 0.;\n";
0942 out << SP << SP << "}";
0943 } else if (fAttrActivations[direction * 3] == "Elu") {
0944 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0945 out << SP << SP << SP << "if (" << OpName << "_input_gate[i] < 0.)\n";
0946 out << SP << SP << SP << SP << OpName << "_input_gate[i] = "
0947 << fAttrActivationAlpha[direction * 3] << " * exp(" << OpName << "_input_gate[i] - 1.);\n";
0948 out << SP << SP << "}\n";
0949 } else if (fAttrActivations[direction * 3] == "Softsign") {
0950 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0951 out << SP << SP << SP << SP << OpName << "_input_gate[i] = " << OpName
0952 << "_input_gate[i] / (1. + abs(" << OpName << "_input_gate[i]));\n";
0953 out << SP << SP << "}\n";
0954 } else {
0955 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0956 out << SP << SP << SP << SP << OpName << "_input_gate[i] = log(1. + exp("
0957 << OpName << "_input_gate[i]));\n";
0958 out << SP << SP << "}\n";
0959 }
0960
0961 if (fAttrInputForget == 0) {
0962
0963 if (fAttrClip > .0) {
0964 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0965 if (fType == "float") {
0966 out << SP << SP << SP << "float x = (" << OpName << "_forget_gate[i] > "
0967 << -fAttrClip << ") ? " << OpName << "_forget_gate[i] : " << -fAttrClip << ";\n";
0968 }
0969 out << SP << SP << SP << OpName << "_forget_gate[i] = (x < " << fAttrClip
0970 << ") ? x : " << fAttrClip << ";\n";
0971 out << SP << SP << "}\n";
0972 }
0973
0974 if (fAttrActivations[direction * 3] == "Relu") {
0975 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0976 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
0977 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 0.;\n";
0978 out << SP << SP << "}\n";
0979 } else if (fAttrActivations[direction * 3] == "Tanh") {
0980 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0981 if (fType == "float") {
0982 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_forget_gate[i]);\n";
0983 }
0984 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = (1. - ex) / (1. + ex);\n";
0985 out << SP << SP << "}\n";
0986 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
0987 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0988 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 1. / (1. + exp(-"
0989 << OpName << "_forget_gate[i]));\n";
0990 out << SP << SP << "}\n";
0991 } else if (fAttrActivations[direction * 3] == "Affine") {
0992 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0993 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = "
0994 << fAttrActivationAlpha[direction * 3] << " * " << OpName << "_forget_gate[i] + "
0995 << fAttrActivationBeta[direction * 3] << ";\n";
0996 out << SP << SP << "}\n";
0997 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
0998 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
0999 if (fType == "float") {
1000 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3]
1001 << " * "<< OpName << "_forget_gate[i]);\n";
1002 }
1003 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = "
1004 << fAttrActivationAlpha[direction * 3] << " * (1. - ex) / (1. + ex);\n";
1005 out << SP << SP << "}\n";
1006 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1007 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1008 if (fType == "float") {
1009 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * "
1010 << OpName << "_forget_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1011 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1012 }
1013 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = (b < 1.) ? b : 1.;\n";
1014 out << SP << SP << "}\n";
1015 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1016 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1017 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1018 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = "
1019 << fAttrActivationAlpha[direction * 3] << " * " << OpName << "_forget_gate[i];\n";
1020 out << SP << SP << "}\n";
1021 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1022 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1023 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < "
1024 << fAttrActivationAlpha[direction * 3] << ")\n";
1025 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = 0.;\n";
1026 out << SP << SP << "}";
1027 } else if (fAttrActivations[direction * 3] == "Elu") {
1028 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1029 out << SP << SP << SP << "if (" << OpName << "_forget_gate[i] < 0.)\n";
1030 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = "
1031 << fAttrActivationAlpha[direction * 3] << " * exp(" << OpName << "_forget_gate[i] - 1.);\n";
1032 out << SP << SP << "}\n";
1033 } else if (fAttrActivations[direction * 3] == "Softsign") {
1034 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1035 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = " << OpName
1036 << "_forget_gate[i] / (1. + abs(" << OpName << "_forget_gate[i]));\n";
1037 out << SP << SP << "}\n";
1038 } else {
1039 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1040 out << SP << SP << SP << SP << OpName << "_forget_gate[i] = log(1. + exp("
1041 << OpName << "_forget_gate[i]));\n";
1042 out << SP << SP << "}\n";
1043 }
1044 }
1045
1046
1047 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1048 out << SP << SP << SP << OpName << "_cell_state[i] = " << OpName << "_input_gate[i] * "
1049 << OpName << "_cell_gate[i];\n";
1050 out << SP << SP << "}\n";
1051
1052 if (fAttrInputForget == 0) {
1053 out << SP << SP << "if (seq == 0) {\n";
1054 if (!fNInitial_c.empty()) {
1055
1056 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1057 out << SP << SP << SP << SP << OpName << "_cell_state[i + offset] += "
1058 << OpName << "_forget_gate[i + offset] * " << OpName << "_initial_cell_state[i];\n";
1059 out << SP << SP << SP << "}\n";
1060 }
1061 out << SP << SP << "} else {\n";
1062
1063 if (direction == 0) {
1064 if (fAttrDirection == "backward") {
1065 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
1066 << num_directions * batch_size * fAttrHiddenSize << ";\n";
1067 } else {
1068 out << SP << SP << SP << "size_t previous_offset = (seq - 1) * "
1069 << num_directions * batch_size * fAttrHiddenSize << ";\n";
1070 }
1071 } else {
1072 out << SP << SP << SP << "size_t previous_offset = (index + 1) * "
1073 << num_directions * batch_size * fAttrHiddenSize << " + " << batch_size * fAttrHiddenSize << ";\n";
1074 }
1075 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1076 out << SP << SP << SP << SP << OpName << "_cell_state[i + offset] += "
1077 << OpName << "_forget_gate[i + offset] * " << OpName << "_cell_state[i + previous_offset];\n";
1078 out << SP << SP << SP << "}\n";
1079 out << SP << SP << "}\n";
1080 }
1081
1082 if (!fNP.empty()) {
1083
1084 if (direction == 0) {
1085 size_t p_offset = 2 * batch_size * fAttrHiddenSize;
1086 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1087 out << SP << SP << SP << SP << OpName << "_output_gate[i + offset] += tensor_"
1088 << fNP << "[i + " << p_offset << "] * " << OpName << "_cell_state[i + offset];\n";
1089 out << SP << SP << SP << "}\n";
1090 } else {
1091 size_t p_offset = 3 * batch_size * fAttrHiddenSize + 2 * batch_size * fAttrHiddenSize;
1092 out << SP << SP << SP << "for (size_t i = 0; i < " << size << "; i++) {\n";
1093 out << SP << SP << SP << SP << OpName << "_output_gate[i + offset] += tensor_"
1094 << fNP << "[i + " << p_offset << "] * " << OpName << "_cell_state[i + offset];\n";
1095 out << SP << SP << SP << "}\n";
1096 }
1097 }
1098
1099
1100 if (fAttrClip > .0) {
1101 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1102 if (fType == "float") {
1103 out << SP << SP << SP << "float x = (" << OpName << "_output_gate[i] > " << -fAttrClip
1104 << ") ? " << OpName << "_output_gate[i] : " << -fAttrClip << ";\n";
1105 }
1106 out << SP << SP << SP << OpName << "_output_gate[i] = (x < " << fAttrClip << ") ? x : "
1107 << fAttrClip << ";\n";
1108 out << SP << SP << "}\n";
1109 }
1110
1111 if (fAttrActivations[direction * 3] == "Relu") {
1112 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1113 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1114 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 0.;\n";
1115 out << SP << SP << "}\n";
1116 } else if (fAttrActivations[direction * 3] == "Tanh") {
1117 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1118 if (fType == "float") {
1119 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_output_gate[i]);\n";
1120 }
1121 out << SP << SP << SP << SP << OpName << "_output_gate[i] = (1. - ex) / (1. + ex);\n";
1122 out << SP << SP << "}\n";
1123 } else if (fAttrActivations[direction * 3] == "Sigmoid") {
1124 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1125 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 1. / (1. + exp(-" << OpName
1126 << "_output_gate[i]));\n";
1127 out << SP << SP << "}\n";
1128 } else if (fAttrActivations[direction * 3] == "Affine") {
1129 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1130 out << SP << SP << SP << SP << OpName << "_output_gate[i] = "
1131 << fAttrActivationAlpha[direction * 3] << " * " << OpName << "_output_gate[i] + "
1132 << fAttrActivationBeta[direction * 3] << ";\n";
1133 out << SP << SP << "}\n";
1134 } else if (fAttrActivations[direction * 3] == "ScaledTanh") {
1135 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1136 if (fType == "float") {
1137 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3]
1138 << " * "<< OpName << "_output_gate[i]);\n";
1139 }
1140 out << SP << SP << SP << SP << OpName << "_output_gate[i] = "
1141 << fAttrActivationAlpha[direction * 3] << " * (1. - ex) / (1. + ex);\n";
1142 out << SP << SP << "}\n";
1143 } else if (fAttrActivations[direction * 3] == "HardSigmoid") {
1144 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1145 if (fType == "float") {
1146 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3] << " * "
1147 << OpName << "_output_gate[i] + " << fAttrActivationBeta[direction * 3] << ";\n";
1148 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1149 }
1150 out << SP << SP << SP << SP << OpName << "_output_gate[i] = (b < 1.) ? b : 1.;\n";
1151 out << SP << SP << "}\n";
1152 } else if (fAttrActivations[direction * 3] == "LeakyRelu") {
1153 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1154 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1155 out << SP << SP << SP << SP << OpName << "_output_gate[i] = "
1156 << fAttrActivationAlpha[direction * 3] << " * " << OpName << "_output_gate[i];\n";
1157 out << SP << SP << "}\n";
1158 } else if (fAttrActivations[direction * 3] == "ThresholdRelu") {
1159 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1160 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < "
1161 << fAttrActivationAlpha[direction * 3] << ")\n";
1162 out << SP << SP << SP << SP << OpName << "_output_gate[i] = 0.;\n";
1163 out << SP << SP << "}";
1164 } else if (fAttrActivations[direction * 3] == "Elu") {
1165 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1166 out << SP << SP << SP << "if (" << OpName << "_output_gate[i] < 0.)\n";
1167 out << SP << SP << SP << SP << OpName << "_output_gate[i] = "
1168 << fAttrActivationAlpha[direction * 3] << " * exp(" << OpName << "_output_gate[i] - 1.);\n";
1169 out << SP << SP << "}\n";
1170 } else if (fAttrActivations[direction * 3] == "Softsign") {
1171 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1172 out << SP << SP << SP << SP << OpName << "_output_gate[i] = " << OpName
1173 << "_output_gate[i] / (1. + abs(" << OpName << "_output_gate[i]));\n";
1174 out << SP << SP << "}\n";
1175 } else {
1176 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1177 out << SP << SP << SP << SP << OpName << "_output_gate[i] = log(1. + exp("
1178 << OpName << "_output_gate[i]));\n";
1179 out << SP << SP << "}\n";
1180 }
1181
1182
1183 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1184 << "_cell_state + offset + " << size << ", "<< OpName << "_new_cell_state + offset);\n";
1185
1186 if (fAttrClip > .0) {
1187 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1188 if (fType == "float") {
1189 out << SP << SP << SP << "float x = (" << OpName << "_new_cell_state[i] > " << -fAttrClip
1190 << ") ? " << OpName << "_new_cell_state[i] : " << -fAttrClip << ";\n";
1191 }
1192 out << SP << SP << SP << OpName << "_new_cell_state[i] = (x < " << fAttrClip << ") ? x : "
1193 << fAttrClip << ";\n";
1194 out << SP << SP << "}\n";
1195 }
1196
1197 if (fAttrActivations[direction * 3 + 2] == "Relu") {
1198 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1199 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1200 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 0.;\n";
1201 out << SP << SP << "}\n";
1202 } else if (fAttrActivations[direction * 3 + 2] == "Tanh") {
1203 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1204 if (fType == "float") {
1205 out << SP << SP << SP << "float ex = exp(-2 * " << OpName << "_new_cell_state[i]);\n";
1206 }
1207 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = (1. - ex) / (1. + ex);\n";
1208 out << SP << SP << "}\n";
1209 } else if (fAttrActivations[direction * 3 + 2] == "Sigmoid") {
1210 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1211 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 1. / (1. + exp(-" << OpName
1212 << "_new_cell_state[i]));\n";
1213 out << SP << SP << "}\n";
1214 } else if (fAttrActivations[direction * 3 + 2] == "Affine") {
1215 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1216 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = "
1217 << fAttrActivationAlpha[direction * 3 + 2] << " * " << OpName << "_new_cell_state[i] + "
1218 << fAttrActivationBeta[direction * 3 + 2] << ";\n";
1219 out << SP << SP << "}\n";
1220 } else if (fAttrActivations[direction * 3 + 2] == "ScaledTanh") {
1221 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1222 if (fType == "float") {
1223 out << SP << SP << SP << "float ex = exp(-2 * " << fAttrActivationBeta[direction * 3 + 2]
1224 << " * "<< OpName << "_new_cell_state[i]);\n";
1225 }
1226 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = "
1227 << fAttrActivationAlpha[direction * 3 + 2] << " * (1. - ex) / (1. + ex);\n";
1228 out << SP << SP << "}\n";
1229 } else if (fAttrActivations[direction * 3 + 2] == "HardSigmoid") {
1230 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1231 if (fType == "float") {
1232 out << SP << SP << SP << "float a = " << fAttrActivationAlpha[direction * 3 + 2] << " * "
1233 << OpName << "_new_cell_state[i] + " << fAttrActivationBeta[direction * 3 + 2] << ";\n";
1234 out << SP << SP << SP << "float b = (a > 0.) ? a : 0.;\n";
1235 }
1236 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = (b < 1.) ? b : 1.;\n";
1237 out << SP << SP << "}\n";
1238 } else if (fAttrActivations[direction * 3 + 2] == "LeakyRelu") {
1239 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1240 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1241 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = "
1242 << fAttrActivationAlpha[direction * 3 + 2] << " * " << OpName << "_new_cell_state[i];\n";
1243 out << SP << SP << "}\n";
1244 } else if (fAttrActivations[direction * 3 + 2] == "ThresholdRelu") {
1245 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1246 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < "
1247 << fAttrActivationAlpha[direction * 3 + 2] << ")\n";
1248 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = 0.;\n";
1249 out << SP << SP << "}";
1250 } else if (fAttrActivations[direction * 3 + 2] == "Elu") {
1251 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1252 out << SP << SP << SP << "if (" << OpName << "_new_cell_state[i] < 0.)\n";
1253 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = "
1254 << fAttrActivationAlpha[direction * 3 + 2] << " * exp(" << OpName << "_new_cell_state[i] - 1.);\n";
1255 out << SP << SP << "}\n";
1256 } else if (fAttrActivations[direction * 3 + 2] == "Softsign") {
1257 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1258 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = " << OpName
1259 << "_new_cell_state[i] / (1. + abs(" << OpName << "_new_cell_state[i]));\n";
1260 out << SP << SP << "}\n";
1261 } else {
1262 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1263 out << SP << SP << SP << SP << OpName << "_new_cell_state[i] = log(1. + exp("
1264 << OpName << "_new_cell_state[i]));\n";
1265 out << SP << SP << "}\n";
1266 }
1267
1268
1269 out << SP << SP << "for (size_t i = offset; i < offset + " << size << "; i++) {\n";
1270 out << SP << SP << SP << OpName << "_hidden_state[i] = " << OpName << "_output_gate[i] * "
1271 << OpName << "_new_cell_state[i];\n";
1272 out << SP << SP << "}\n";
1273 out << SP << "}\n";
1274 }
1275
1276
1277 if (!fNSequence_lens.empty()) {
1278 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1279 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1280 out << SP << SP << SP << "if (seq >= tensor_" << fNSequence_lens << "[batch]) {\n";
1281 for (size_t direction = 0; direction < num_directions; direction++) {
1282 out << SP << SP << SP << SP << SP << "for (size_t h = 0; h < " << fAttrHiddenSize << "; h++) {\n";
1283 out << SP << SP << SP << SP << SP << SP << "size_t idx = seq * "
1284 << num_directions * batch_size * fAttrHiddenSize + direction * batch_size * fAttrHiddenSize
1285 << " + batch * " << fAttrHiddenSize << " + h;\n";
1286 out << SP << SP << SP << SP << SP << SP << OpName << "_cell_state[idx] = 0.;\n";
1287 out << SP << SP << SP << SP << SP << SP << OpName << "_hidden_state[idx] = 0.;\n";
1288 out << SP << SP << SP << SP << SP << "}\n";
1289 }
1290 out << SP << SP << SP << "}\n";
1291 out << SP << SP << "}\n";
1292 out << SP << "}\n";
1293 }
1294
1295
1296 if (fAttrLayout == 0) {
1297 if (!fNY_h.empty()) {
1298
1299 if (fNSequence_lens.empty()) {
1300 size_t y_h_size = batch_size * fAttrHiddenSize;
1301 if (fAttrDirection == "backward") {
1302 out << SP << "std::copy(" << OpName << "_hidden_state, " << OpName << "_hidden_state + "
1303 << y_h_size << ", tensor_" << fNY_h << ");\n";
1304 } else {
1305 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
1306 out << SP << "std::copy(" << OpName << "_hidden_state + " << offset << ", " << OpName
1307 << "_hidden_state + " << offset << " + " << y_h_size << ", tensor_" << fNY_h << ");\n";
1308 }
1309 if (num_directions == 2) {
1310 out << SP << "std::copy(" << OpName << "_hidden_state + " << y_h_size << ", " << OpName
1311 << "_hidden_state + " << 2 * y_h_size << ", tensor_" << fNY_h << " + " << y_h_size << ");\n";
1312 }
1313 } else {
1314 if (fAttrDirection == "backward") {
1315 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1316 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1317 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1318 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + offset);\n";
1319 out << SP << "}\n";
1320 } else {
1321 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1322 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1323 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1324 << " + batch * " << fAttrHiddenSize << ";\n";
1325 out << SP << SP << "size_t y_h_offset = batch * " << fAttrHiddenSize << ";\n";
1326 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1327 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1328 out << SP << "}\n";
1329 }
1330 if (num_directions == 2) {
1331 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1332 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
1333 << " + batch * " << fAttrHiddenSize << ";\n";
1334 out << SP << SP << "size_t y_h_offset = " << batch_size * fAttrHiddenSize
1335 << " + batch * " << fAttrHiddenSize << ";\n";
1336 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1337 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1338 out << SP << "}\n";
1339 }
1340 }
1341 }
1342 if (!fNY_c.empty()) {
1343
1344 if (fNSequence_lens.empty()) {
1345 size_t y_h_size = batch_size * fAttrHiddenSize;
1346 if (fAttrDirection == "backward") {
1347 out << SP << "std::copy(" << OpName << "_cell_state, " << OpName << "_hidden_state + "
1348 << y_h_size << ", tensor_" << fNY_c << ");\n";
1349 } else {
1350 size_t offset = (seq_length - 1) * num_directions * batch_size * fAttrHiddenSize;
1351 out << SP << "std::copy(" << OpName << "_cell_state + " << offset << ", " << OpName
1352 << "_cell_state + " << offset << " + " << y_h_size << ", tensor_" << fNY_c << ");\n";
1353 }
1354 if (num_directions == 2) {
1355 out << SP << "std::copy(" << OpName << "_cell_state + " << y_h_size << ", " << OpName
1356 << "_cell_state + " << 2 * y_h_size << ", tensor_" << fNY_c << " + " << y_h_size << ");\n";
1357 }
1358 } else {
1359 if (fAttrDirection == "backward") {
1360 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1361 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1362 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1363 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + offset);\n";
1364 out << SP << "}\n";
1365 } else {
1366 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1367 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1368 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1369 << " + batch * " << fAttrHiddenSize << ";\n";
1370 out << SP << SP << "size_t y_h_offset = batch * " << fAttrHiddenSize << ";\n";
1371 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1372 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1373 out << SP << "}\n";
1374 }
1375 if (num_directions == 2) {
1376 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1377 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize
1378 << " + batch * " << fAttrHiddenSize << ";\n";
1379 out << SP << SP << "size_t y_h_offset = " << batch_size * fAttrHiddenSize
1380 << " + batch * " << fAttrHiddenSize << ";\n";
1381 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1382 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1383 out << SP << "}\n";
1384 }
1385 }
1386 }
1387 } else {
1388 if (!fNY.empty()) {
1389
1390 for (size_t direction = 0; direction < num_directions; direction++) {
1391 out << SP << "for (size_t seq = 0; seq < " << seq_length << "; seq++) {\n";
1392 out << SP << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1393 out << SP << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1394 << " + " << direction * batch_size * fAttrHiddenSize << " + batch * " << fAttrHiddenSize << ";\n";
1395 out << SP << SP << SP << "size_t y_offset = batch * " << seq_length * num_directions * fAttrHiddenSize
1396 << " + seq * " << num_directions * fAttrHiddenSize << " + " << direction * fAttrHiddenSize << ";\n";
1397 out << SP << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1398 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY << " + y_offset);\n";
1399 out << SP << SP << "}\n";
1400 out << SP << "}\n";
1401 }
1402 }
1403 if (!fNY_h.empty()) {
1404
1405 if (fAttrDirection == "backward") {
1406 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1407 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1408 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1409 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1410 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1411 out << SP << "}\n";
1412 } else {
1413 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1414 if (fNSequence_lens.empty()) {
1415 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
1416 } else {
1417 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1418 }
1419 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1420 << " + batch * " << fAttrHiddenSize << ";\n";
1421 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1422 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1423 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1424 out << SP << "}\n";
1425 }
1426 if (num_directions == 2) {
1427 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1428 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
1429 << fAttrHiddenSize << ";\n";
1430 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1431 << fAttrHiddenSize << ";\n";
1432 out << SP << SP << "std::copy(" << OpName << "_hidden_state + offset, " << OpName
1433 << "_hidden_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_h << " + y_h_offset);\n";
1434 out << SP << "}\n";
1435 }
1436 }
1437
1438 if (!fNY_c.empty()) {
1439
1440 if (fAttrDirection == "backward") {
1441 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1442 out << SP << SP << "size_t offset = batch * " << fAttrHiddenSize << ";\n";
1443 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1444 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1445 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1446 out << SP << "}\n";
1447 } else {
1448 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1449 if (fNSequence_lens.empty()) {
1450 out << SP << SP << "size_t seq = " << seq_length - 1 << ";\n";
1451 } else {
1452 out << SP << SP << "size_t seq = " << "tensor_" << fNSequence_lens << "[batch] - 1;\n";
1453 }
1454 out << SP << SP << "size_t offset = seq * " << num_directions * batch_size * fAttrHiddenSize
1455 << " + batch * " << fAttrHiddenSize << ";\n";
1456 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << ";\n";
1457 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1458 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1459 out << SP << "}\n";
1460 }
1461 if (num_directions == 2) {
1462 out << SP << "for (size_t batch = 0; batch < " << batch_size << "; batch++) {\n";
1463 out << SP << SP << "size_t offset = " << batch_size * fAttrHiddenSize << " + batch * "
1464 << fAttrHiddenSize << ";\n";
1465 out << SP << SP << "size_t y_h_offset = batch * " << num_directions * fAttrHiddenSize << " + "
1466 << fAttrHiddenSize << ";\n";
1467 out << SP << SP << "std::copy(" << OpName << "_cell_state + offset, " << OpName
1468 << "_cell_state + offset + " << fAttrHiddenSize << ", tensor_" << fNY_c << " + y_h_offset);\n";
1469 out << SP << "}\n";
1470 }
1471 }
1472 }
1473
1474 return out.str();
1475 }
1476
1477 }
1478 }
1479 }
1480
1481 #endif