File indexing completed on 2025-09-18 09:32:38
0001 #ifndef TMVA_SOFIE_ROPERATOR_RNN
0002 #define TMVA_SOFIE_ROPERATOR_RNN
0003
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/SOFIE_common.hxx"

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0011
0012 namespace TMVA {
0013 namespace Experimental {
0014 namespace SOFIE {
0015
0016
0017
0018
0019
0020
0021
0022 template <typename T> class ROperator_RNN final : public ROperator {
0023 private:
0024 std::vector<float> fAttrActivationAlpha;
0025 std::vector<float> fAttrActivationBeta;
0026 std::vector<std::string> fAttrActivations;
0027 float fAttrClip;
0028 std::string fAttrDirection;
0029 size_t fAttrHiddenSize;
0030 size_t fAttrLayout;
0031
0032 std::string fNX;
0033 std::string fNW;
0034 std::string fNR;
0035 std::string fNB;
0036 std::string fNSequence_lens;
0037 std::string fNInitial_h;
0038 std::string fNY;
0039 std::string fNY_h;
0040
0041 std::vector<size_t> fShapeX;
0042 std::vector<size_t> fShapeW;
0043 std::vector<size_t> fShapeR;
0044 std::vector<size_t> fShapeB;
0045 std::vector<size_t> fShapeSequence_lens;
0046 std::vector<size_t> fShapeInitial_h;
0047 std::vector<size_t> fShapeY;
0048 std::vector<size_t> fShapeY_h;
0049
0050 std::string fType;
0051
0052 public:
0053
0054 ROperator_RNN() {}
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074 ROperator_RNN(std::vector<float> activation_alpha,
0075 std::vector<float> activation_beta,
0076 std::vector<std::string> activations, float clip,
0077 std::string direction, size_t hidden_size, size_t layout,
0078 std::string nameX, std::string nameW, std::string nameR,
0079 std::string nameB, std::string nameSequence_lens,
0080 std::string nameInitial_h, std::string nameY,
0081 std::string nameY_h)
0082 : fAttrActivationAlpha(activation_alpha),
0083 fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0084 fAttrClip(clip), fAttrDirection(direction),
0085 fAttrHiddenSize(hidden_size), fAttrLayout(layout),
0086 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0087 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0088 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0089 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0090 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)) {
0091 if (std::is_same<T, float>::value) {
0092 fType = "float";
0093 } else {
0094 throw std::runtime_error(
0095 "TMVA SOFIE Encountered unsupported type parsing a RNN operator");
0096 }
0097
0098 fInputTensorNames = { fNX, fNW, fNR };
0099 if(!fNB.empty()){
0100 fInputTensorNames.emplace_back(fNB);
0101 }
0102 if(!fNSequence_lens.empty()){
0103 fInputTensorNames.emplace_back(fNSequence_lens);
0104 }
0105 if(!fNInitial_h.empty()){
0106 fInputTensorNames.emplace_back(fNInitial_h);
0107 }
0108
0109 fOutputTensorNames = { };
0110 if(!fNY.empty()){
0111 fOutputTensorNames.emplace_back(fNY);
0112 }
0113 if(!fNY_h.empty()){
0114 fOutputTensorNames.emplace_back(fNY_h);
0115 }
0116 }
0117
0118
0119
0120
0121
0122 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0123
0124
0125
0126
0127
0128 std::vector<std::vector<size_t>>
0129 ShapeInference(std::vector<std::vector<size_t>> input);
0130
0131
0132
0133
0134
0135 void Initialize(RModel &);
0136
0137
0138
0139
0140
0141 std::string Generate(std::string OpName);
0142
0143
0144 std::string GenerateSessionMembersCode(std::string opName);
0145
0146
0147
0148 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0149 };
0150
0151 }
0152 }
0153 }
0154
0155
0156 #include "TMVA/ROperator_RNN.icc"
0157
0158 #endif