File indexing completed on 2025-12-06 10:01:47
0001 #ifndef TMVA_SOFIE_ROPERATOR_LSTM
0002 #define TMVA_SOFIE_ROPERATOR_LSTM
0003
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/SOFIE_common.hxx"

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0012
0013 namespace TMVA {
0014 namespace Experimental {
0015 namespace SOFIE {
0016
0017
0018
0019
0020
0021
0022
0023 template <typename T> class ROperator_LSTM final : public ROperator {
0024 private:
0025 std::vector<float> fAttrActivationAlpha;
0026 std::vector<float> fAttrActivationBeta;
0027 std::vector<std::string> fAttrActivations;
0028 float fAttrClip;
0029 std::string fAttrDirection;
0030 size_t fAttrHiddenSize;
0031 size_t fAttrInputForget;
0032 size_t fAttrLayout;
0033
0034 std::string fNX;
0035 std::string fNW;
0036 std::string fNR;
0037 std::string fNB;
0038 std::string fNSequence_lens;
0039 std::string fNInitial_h;
0040 std::string fNInitial_c;
0041 std::string fNP;
0042 std::string fNY;
0043 std::string fNY_h;
0044 std::string fNY_c;
0045
0046 std::vector<size_t> fShapeX;
0047 std::vector<size_t> fShapeW;
0048 std::vector<size_t> fShapeR;
0049 std::vector<size_t> fShapeB;
0050 std::vector<size_t> fShapeSequence_lens;
0051 std::vector<size_t> fShapeInitial_h;
0052 std::vector<size_t> fShapeInitial_c;
0053 std::vector<size_t> fShapeP;
0054 std::vector<size_t> fShapeY;
0055 std::vector<size_t> fShapeY_h;
0056 std::vector<size_t> fShapeY_c;
0057
0058 std::string fType;
0059
0060 public:
0061
0062 ROperator_LSTM() {}
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082
0083
0084
0085
0086 ROperator_LSTM(std::vector<float> activation_alpha,
0087 std::vector<float> activation_beta,
0088 std::vector<std::string> activations, float clip,
0089 std::string direction, size_t hidden_size,
0090 size_t input_forget, size_t layout,
0091 std::string nameX, std::string nameW, std::string nameR,
0092 std::string nameB, std::string nameSequence_lens,
0093 std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
0094 std::string nameY, std::string nameY_h, std::string nameY_c)
0095 : fAttrActivationAlpha(activation_alpha),
0096 fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0097 fAttrClip(clip), fAttrDirection(direction), fAttrHiddenSize(hidden_size),
0098 fAttrInputForget(input_forget), fAttrLayout(layout),
0099 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0100 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0101 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0102 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0103 fNInitial_c(UTILITY::Clean_name(nameInitial_c)), fNP(UTILITY::Clean_name(nameP)),
0104 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)),
0105 fNY_c(UTILITY::Clean_name(nameY_c)) {
0106 if (std::is_same<T, float>::value) {
0107 fType = "float";
0108 } else {
0109 throw std::runtime_error(
0110 "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
0111 }
0112
0113 fInputTensorNames = { fNX, fNW, fNR };
0114 if (!fNB.empty()){
0115 fInputTensorNames.emplace_back(fNB);
0116 }
0117 if (!fNSequence_lens.empty()){
0118 fInputTensorNames.emplace_back(fNSequence_lens);
0119 }
0120 if (!fNInitial_h.empty()){
0121 fInputTensorNames.emplace_back(fNInitial_h);
0122 }
0123 if (!fNInitial_c.empty()){
0124 fInputTensorNames.emplace_back(fNInitial_c);
0125 }
0126 if (!fNP.empty()){
0127 fInputTensorNames.emplace_back(fNP);
0128 }
0129
0130 fOutputTensorNames = { };
0131 if (!fNY.empty()){
0132 fOutputTensorNames.emplace_back(fNY);
0133 }
0134 if (!fNY_h.empty()){
0135 fOutputTensorNames.emplace_back(fNY_h);
0136 }
0137 if (!fNY_c.empty()){
0138 fOutputTensorNames.emplace_back(fNY_c);
0139 }
0140 }
0141
0142
0143
0144
0145
0146 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0147
0148
0149
0150
0151
0152 std::vector<std::vector<size_t>>
0153 ShapeInference(std::vector<std::vector<size_t>> input);
0154
0155
0156
0157
0158
0159 void Initialize(RModel &);
0160
0161
0162
0163
0164
0165 std::string Generate(std::string OpName);
0166
0167
0168
0169
0170
0171 std::string GenerateSessionMembersCode(std::string opName);
0172
0173
0174
0175 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0176 };
0177
0178 }
0179 }
0180 }
0181
0182
0183 #include "TMVA/ROperator_LSTM.icc"
0184
0185 #endif