Back to home page

EIC code displayed by LXR

 
 

    


File indexing completed on 2025-01-18 10:11:08

0001 #ifndef TMVA_SOFIE_ROPERATOR_RNN
0002 #define TMVA_SOFIE_ROPERATOR_RNN
0003 
#include "TMVA/RModel.hxx"
#include "TMVA/ROperator.hxx"
#include "TMVA/SOFIE_common.hxx"

#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
0011 
0012 namespace TMVA {
0013 namespace Experimental {
0014 namespace SOFIE {
0015 
0016 /*! \brief Recurrent Neural Network operator
0017  *
0018  * Inference code generation for one-layer vanilla RNN. Supports forward, reverse and bidirectional RNNs.
0019  * See the <a href="https://github.com/onnx/onnx/blob/master/docs/Operators.md#RNN">ONNX documentation</a>
0020  * for details about the supported RNN architectures.
0021  */
0022 template <typename T> class ROperator_RNN final : public ROperator {
0023  private:
0024    std::vector<float> fAttrActivationAlpha;   ///< Scaling values used by some activation functions
0025    std::vector<float> fAttrActivationBeta;    ///< Scaling values used by some activation functions
0026    std::vector<std::string> fAttrActivations; ///< Activation functions
0027    float fAttrClip;                           ///< Clip threshold
0028    std::string fAttrDirection;                ///< Direction of processing
0029    size_t fAttrHiddenSize;                    ///< Number of the hidden layers
0030    size_t fAttrLayout;                        ///< Data layout
0031 
0032    std::string fNX;                           ///< Name of the input
0033    std::string fNW;                           ///< Name of the weights
0034    std::string fNR;                           ///< Name of the recurrence
0035    std::string fNB;                           ///< Name of the bias
0036    std::string fNSequence_lens;               ///< Name of the length of the sequences
0037    std::string fNInitial_h;                   ///< Name of the initial value of the hidden states
0038    std::string fNY;                           ///< Name of the output
0039    std::string fNY_h;                         ///< Name of the last sequence of the output
0040 
0041    std::vector<size_t> fShapeX;               ///< Shape of the input
0042    std::vector<size_t> fShapeW;               ///< Shape of the weights
0043    std::vector<size_t> fShapeR;               ///< Shape of the recurrence
0044    std::vector<size_t> fShapeB;               ///< Shape of the bias
0045    std::vector<size_t> fShapeSequence_lens;   ///< Shape of the length of the sequences
0046    std::vector<size_t> fShapeInitial_h;       ///< Shape of the initial value of the hidden states
0047    std::vector<size_t> fShapeY;               ///< Shape of the output
0048    std::vector<size_t> fShapeY_h;             ///< Shape of the last sequence of the output
0049 
0050    std::string fType; ///< Type of the tensors
0051 
0052  public:
0053    /*! Default constructor of ROperator_RNN */
0054    ROperator_RNN() {}
0055 
0056    /*! \brief Constructor of ROperator_RNN from the attributes
0057     *
0058     * \param activation_alpha scaling values used by some activation functions
0059     * \param activation_beta scaling values used by some activation functions
0060     * \param activations activation functions
0061     * \param clip clip threshold
0062     * \param direction direction of processing of the sequneces
0063     * \param hidden_size number of hidden layers
0064     * \param layout data layout
0065     * \param nameX name of the input tensor
0066     * \param nameW name of the weight tensor
0067     * \param nameR name of the recurrence tensor
0068     * \param nameB name of the bias tensor
0069     * \param nameSequence_lens name of the length of the sequences
0070     * \param nameInitial_h name of the initial value of the hidden states
0071     * \param nameY name of the output
0072     * \param nameY_h name of the last sequence of the output
0073     */
0074    ROperator_RNN(std::vector<float> activation_alpha,
0075                  std::vector<float> activation_beta,
0076                  std::vector<std::string> activations, float clip,
0077                  std::string direction, size_t hidden_size, size_t layout,
0078                  std::string nameX, std::string nameW, std::string nameR,
0079                  std::string nameB, std::string nameSequence_lens,
0080                  std::string nameInitial_h, std::string nameY,
0081                  std::string nameY_h)
0082        : fAttrActivationAlpha(activation_alpha),
0083          fAttrActivationBeta(activation_beta), fAttrActivations(activations),
0084          fAttrClip(clip), fAttrDirection(direction),
0085          fAttrHiddenSize(hidden_size), fAttrLayout(layout),
0086          fNX(UTILITY::Clean_name(nameX)), fNW(UTILITY::Clean_name(nameW)),
0087          fNR(UTILITY::Clean_name(nameR)), fNB(UTILITY::Clean_name(nameB)),
0088          fNSequence_lens(UTILITY::Clean_name(nameSequence_lens)),
0089          fNInitial_h(UTILITY::Clean_name(nameInitial_h)),
0090          fNY(UTILITY::Clean_name(nameY)), fNY_h(UTILITY::Clean_name(nameY_h)) {
0091       if (std::is_same<T, float>::value) {
0092          fType = "float";
0093       } else {
0094          throw std::runtime_error(
0095              "TMVA SOFIE Encountered unsupported type parsing a RNN operator");
0096       }
0097    }
0098 
0099    /*! \brief Infers the type of the output tensors
0100     *
0101     * \param input type of the input tensors
0102     */
0103    std::vector<ETensorType> TypeInference(std::vector<ETensorType> input);
0104 
0105    /*! \brief Infers the shape of the output tensors
0106     *
0107     * \param input shape of the input tensors
0108     */
0109    std::vector<std::vector<size_t>>
0110    ShapeInference(std::vector<std::vector<size_t>> input);
0111 
0112    /*! \brief Initialize the model
0113     *
0114     * \param model Model
0115     */
0116    void Initialize(RModel &model);
0117 
0118    /*! \brief Generates the inference code
0119     *
0120     * \param OpName name of the operator
0121     */
0122    std::string Generate(std::string OpName);
0123 
0124    // generate code for Session data members (e.g. internal vectors)
0125    std::string GenerateSessionMembersCode(std::string opName);
0126 
0127    /*! \brief Returns the blas routines needed to compile the generated code
0128     */
0129    std::vector<std::string> GetBlasRoutines() { return { std::string("Gemm"), std::string("Axpy") }; }
0130 };
0131 
0132 } // namespace SOFIE
0133 } // namespace Experimental
0134 } // namespace TMVA
0135 
0136 // Implementation of the ROperator_RNN class
0137 #include "TMVA/ROperator_RNN.icc"
0138 
0139 #endif