File indexing completed on 2025-01-30 10:23:00
0001 #ifndef TMVA_SOFIE_ROPERATOR_LSTM
0002 #define TMVA_SOFIE_ROPERATOR_LSTM
0003
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/SOFIE_common.hxx"

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0012
0013 namespace TMVA {
0014 namespace Experimental {
0015 namespace SOFIE {
0016
0017
0018
0019
0020
0021
0022
0023 template <typename T> class ROperator_LSTM final : public ROperator {
0024 private:
0025 std::vector<float> fAttrActivationAlpha;
0026 std::vector<float> fAttrActivationBeta;
0027 std::vector<std::string> fAttrActivations;
0028 float fAttrClip;
0029 std::string fAttrDirection;
0030 size_t fAttrHiddenSize;
0031 size_t fAttrInputForget;
0032 size_t fAttrLayout;
0033
0034 std::string fNX;
0035 std::string fNW;
0036 std::string fNR;
0037 std::string fNB;
0038 std::string fNSequence_lens;
0039 std::string fNInitial_h;
0040 std::string fNInitial_c;
0041 std::string fNP;
0042 std::string fNY;
0043 std::string fNY_h;
0044 std::string fNY_c;
0045
0046 std::vector<size_t> fShapeX;
0047 std::vector<size_t> fShapeW;
0048 std::vector<size_t> fShapeR;
0049 std::vector<size_t> fShapeB;
0050 std::vector<size_t> fShapeSequence_lens;
0051 std::vector<size_t> fShapeInitial_h;
0052 std::vector<size_t> fShapeInitial_c;
0053 std::vector<size_t> fShapeP;
0054 std::vector<size_t> fShapeY;
0055 std::vector<size_t> fShapeY_h;
0056 std::vector<size_t> fShapeY_c;
0057
0058 std::string fType;
0059
0060 public:
0061
0062 ROperator_LSTM() {}
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074
0075
0076
0077
0078
0079
0080
0081
0082
0083
0084
0085
0086 ROperator_LSTM(std::vector<float> activation_alpha,
0087 std::vector<float> activation_beta,
0088 std::vector<std::string> activations, float clip,
0089 std::string direction, size_t hidden_size,
0090 size_t input_forget, size_t layout,
0091 std::string nameX, std::string nameW, std::string nameR,
0092 std::string nameB, std::string nameSequence_lens,
0093 std::string nameInitial_h, std::string nameInitial_c, std::string nameP,
0094 std::string nameY, std::string nameY_h, std::string nameY_c)
0095 : fAttrActivationAlpha(activation_alpha),
0096 fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0097 fAttrClip(clip), fAttrDirection(direction), fAttrHiddenSize(hidden_size),
0098 fAttrInputForget(input_forget), fAttrLayout(layout),
0099 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0100 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0101 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0102 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0103 fNInitial_c(UTILITY::Clean_name(nameInitial_c)), fNP(UTILITY::Clean_name(nameP)),
0104 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)),
0105 fNY_c(UTILITY::Clean_name(nameY_c)) {
0106 if (std::is_same<T, float>::value) {
0107 fType = "float";
0108 } else {
0109 throw std::runtime_error(
0110 "TMVA SOFIE Encountered unsupported type parsing a LSTM operator");
0111 }
0112 }
0113
0114
0115
0116
0117
0118 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0119
0120
0121
0122
0123
0124 std::vector<std::vector<size_t>>
0125 ShapeInference(std::vector<std::vector<size_t>> input);
0126
0127
0128
0129
0130
0131 void Initialize(RModel &model);
0132
0133
0134
0135
0136
0137 std::string Generate(std::string OpName);
0138
0139
0140
0141
0142
0143 std::string GenerateSessionMembersCode(std::string opName);
0144
0145
0146
0147 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0148 };
0149
0150 }
0151 }
0152 }
0153
0154
0155 #include "TMVA/ROperator_LSTM.icc"
0156
0157 #endif