File indexing completed on 2025-01-18 10:11:08
0001 #ifndef TMVA_SOFIE_ROPERATOR_RNN
0002 #define TMVA_SOFIE_ROPERATOR_RNN
0003
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/SOFIE_common.hxx"

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0011
0012 namespace TMVA {
0013 namespace Experimental {
0014 namespace SOFIE {
0015
0016
0017
0018
0019
0020
0021
0022 template <typename T> class ROperator_RNN final : public ROperator {
0023 private:
0024 std::vector<float> fAttrActivationAlpha;
0025 std::vector<float> fAttrActivationBeta;
0026 std::vector<std::string> fAttrActivations;
0027 float fAttrClip;
0028 std::string fAttrDirection;
0029 size_t fAttrHiddenSize;
0030 size_t fAttrLayout;
0031
0032 std::string fNX;
0033 std::string fNW;
0034 std::string fNR;
0035 std::string fNB;
0036 std::string fNSequence_lens;
0037 std::string fNInitial_h;
0038 std::string fNY;
0039 std::string fNY_h;
0040
0041 std::vector<size_t> fShapeX;
0042 std::vector<size_t> fShapeW;
0043 std::vector<size_t> fShapeR;
0044 std::vector<size_t> fShapeB;
0045 std::vector<size_t> fShapeSequence_lens;
0046 std::vector<size_t> fShapeInitial_h;
0047 std::vector<size_t> fShapeY;
0048 std::vector<size_t> fShapeY_h;
0049
0050 std::string fType;
0051
0052 public:
0053
0054 ROperator_RNN() {}
0055
0056
0057
0058
0059
0060
0061
0062
0063
0064
0065
0066
0067
0068
0069
0070
0071
0072
0073
0074 ROperator_RNN(std::vector<float> activation_alpha,
0075 std::vector<float> activation_beta,
0076 std::vector<std::string> activations, float clip,
0077 std::string direction, size_t hidden_size, size_t layout,
0078 std::string nameX, std::string nameW, std::string nameR,
0079 std::string nameB, std::string nameSequence_lens,
0080 std::string nameInitial_h, std::string nameY,
0081 std::string nameY_h)
0082 : fAttrActivationAlpha(activation_alpha),
0083 fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0084 fAttrClip(clip), fAttrDirection(direction),
0085 fAttrHiddenSize(hidden_size), fAttrLayout(layout),
0086 fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0087 fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0088 fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0089 fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0090 fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)) {
0091 if (std::is_same<T, float>::value) {
0092 fType = "float";
0093 } else {
0094 throw std::runtime_error(
0095 "TMVA SOFIE Encountered unsupported type parsing a RNN operator");
0096 }
0097 }
0098
0099
0100
0101
0102
0103 std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0104
0105
0106
0107
0108
0109 std::vector<std::vector<size_t>>
0110 ShapeInference(std::vector<std::vector<size_t>> input);
0111
0112
0113
0114
0115
0116 void Initialize(RModel &model);
0117
0118
0119
0120
0121
0122 std::string Generate(std::string OpName);
0123
0124
0125 std::string GenerateSessionMembersCode(std::string opName);
0126
0127
0128
0129 std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0130 };
0131
0132 }
0133 }
0134 }
0135
0136
0137 #include "TMVA/ROperator_RNN.icc"
0138
0139 #endif