Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-09-18 09:32:38

0001 #ifndef TMVA_SOFIE_ROPERATOR_RNN
0002 #define TMVA_SOFIE_ROPERATOR_RNN
0003 
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/SOFIE_common.hxx"

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0011 
0012 namespace TMVA {
0013 namespace Experimental {
0014 namespace SOFIE {
0015 
0016 /*! \brief Recurrent Neural Network operator
0017  *
0018  * Inference code generation for one-layer vanilla RNN. Supports forward, reverse and bidirectional RNNs.
0019  * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN">ONNX documentation</a>
0020  * for details about the supported RNN architectures.
0021  */
0022 template <typename T> class ROperator_RNN final : public ROperator {
0023  private:
0024    std::vector<float> fAttrActivationAlpha;   ///< Scaling values used by some activation functions
0025    std::vector<float> fAttrActivationBeta;    ///< Scaling values used by some activation functions
0026    std::vector<std::string> fAttrActivations; ///< Activation functions
0027    float fAttrClip;                           ///< Clip threshold
0028    std::string fAttrDirection;                ///< Direction of processing
0029    size_t fAttrHiddenSize;                    ///< Number of the hidden layers
0030    size_t fAttrLayout;                        ///< Data layout
0031 
0032    std::string fNX;                           ///< Name of the input
0033    std::string fNW;                           ///< Name of the weights
0034    std::string fNR;                           ///< Name of the recurrence
0035    std::string fNB;                           ///< Name of the bias
0036    std::string fNSequence_lens;               ///< Name of the length of the sequences
0037    std::string fNInitial_h;                   ///< Name of the initial value of the hidden states
0038    std::string fNY;                           ///< Name of the output
0039    std::string fNY_h;                         ///< Name of the last sequence of the output
0040 
0041    std::vector<size_t> fShapeX;               ///< Shape of the input
0042    std::vector<size_t> fShapeW;               ///< Shape of the weights
0043    std::vector<size_t> fShapeR;               ///< Shape of the recurrence
0044    std::vector<size_t> fShapeB;               ///< Shape of the bias
0045    std::vector<size_t> fShapeSequence_lens;   ///< Shape of the length of the sequences
0046    std::vector<size_t> fShapeInitial_h;       ///< Shape of the initial value of the hidden states
0047    std::vector<size_t> fShapeY;               ///< Shape of the output
0048    std::vector<size_t> fShapeY_h;             ///< Shape of the last sequence of the output
0049 
0050    std::string fType; ///< Type of the tensors
0051 
0052  public:
0053    /*! Default constructor of ROperator_RNN */
0054    ROperator_RNN() {}
0055 
0056    /*! \brief Constructor of ROperator_RNN from the attributes
0057     *
0058     * \param activation_alpha scaling values used by some activation functions
0059     * \param activation_beta scaling values used by some activation functions
0060     * \param activations activation functions
0061     * \param clip clip threshold
0062     * \param direction direction of processing of the sequneces
0063     * \param hidden_size number of hidden layers
0064     * \param layout data layout
0065     * \param nameX name of the input tensor
0066     * \param nameW name of the weight tensor
0067     * \param nameR name of the recurrence tensor
0068     * \param nameB name of the bias tensor
0069     * \param nameSequence_lens name of the length of the sequences
0070     * \param nameInitial_h name of the initial value of the hidden states
0071     * \param nameY name of the output
0072     * \param nameY_h name of the last sequence of the output
0073     */
0074    ROperator_RNN(std::vector<float> activation_alpha,
0075                  std::vector<float> activation_beta,
0076                  std::vector<std::string> activations, float clip,
0077                  std::string direction, size_t hidden_size, size_t layout,
0078                  std::string nameX, std::string nameW, std::string nameR,
0079                  std::string nameB, std::string nameSequence_lens,
0080                  std::string nameInitial_h, std::string nameY,
0081                  std::string nameY_h)
0082        : fAttrActivationAlpha(activation_alpha),
0083          fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0084          fAttrClip(clip), fAttrDirection(direction),
0085          fAttrHiddenSize(hidden_size), fAttrLayout(layout),
0086          fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0087          fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0088          fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0089          fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0090          fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)) {
0091       if (std::is_same<T, float>::value) {
0092          fType = "float";
0093       } else {
0094          throw std::runtime_error(
0095              "TMVA SOFIE Encountered unsupported type parsing a RNN operator");
0096       }
0097 
0098       fInputTensorNames = { fNX, fNW, fNR };
0099       if(!fNB.empty()){
0100          fInputTensorNames.emplace_back(fNB);
0101       }
0102       if(!fNSequence_lens.empty()){
0103          fInputTensorNames.emplace_back(fNSequence_lens);
0104       }
0105       if(!fNInitial_h.empty()){
0106          fInputTensorNames.emplace_back(fNInitial_h);
0107       }
0108 
0109       fOutputTensorNames = { };
0110       if(!fNY.empty()){
0111          fOutputTensorNames.emplace_back(fNY);
0112       }
0113       if(!fNY_h.empty()){
0114          fOutputTensorNames.emplace_back(fNY_h);
0115       }
0116    }
0117 
0118    /*! \brief Infers the type of the output tensors
0119     *
0120     * \param input type of the input tensors
0121     */
0122    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0123 
0124    /*! \brief Infers the shape of the output tensors
0125     *
0126     * \param input shape of the input tensors
0127     */
0128    std::vector<std::vector<size_t>>
0129    ShapeInference(std::vector<std::vector<size_t>> input);
0130 
0131    /*! \brief Initialize the model
0132     *
0133     * \param model Model
0134     */
0135    void Initialize(RModel &);
0136 
0137    /*! \brief Generates the inference code
0138     *
0139     * \param OpName name of the operator
0140     */
0141    std::string Generate(std::string OpName);
0142 
0143    // generate code for Session data members (e.g. internal vectors)
0144    std::string GenerateSessionMembersCode(std::string opName);
0145 
0146    /*! \brief Returns the blas routines needed to compile the generated code
0147     */
0148    std::vector<std::string> GetBlasRoutines()  { return { std::string("Gemm"), std::string("Axpy") }; }
0149 };
0150 
0151 } // namespace SOFIE
0152 } // namespace Experimental
0153 } // namespace TMVA
0154 
0155 // Implementation of the ROperator_RNN class
0156 #include "TMVA/ROperator_RNN.icc"
0157 
0158 #endif